* add class UnpersistedEventContext * modify create new client event to create unpersistedeventcontexts * persist event contexts after creation * fix tests to persist unpersisted event contexts * cleanup * misc lints + cleanup * changelog + fix comments * lints * fix batch insertion? * reduce redundant calculation * add unpersisted event classes * rework compute_event_context, split into function that returns unpersisted event context and then persists it * use calculate_context_info to create unpersisted event contexts * update typing * $%#^&* * black * fix comments and consolidate classes, use attr.s for class * requested changes * lint * requested changes * requested changes * refactor to be stupidly explicit * clearer renaming and flow * make partial state non-optional * update docstrings --------- Co-authored-by: Erik Johnston <erik@matrix.org>tags/v1.78.0rc1
@@ -0,0 +1 @@ | |||
Add a class UnpersistedEventContext to allow for the batching up of storing state groups. |
@@ -11,6 +11,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from abc import ABC, abstractmethod | |||
from typing import TYPE_CHECKING, List, Optional, Tuple | |||
import attr | |||
@@ -26,8 +27,51 @@ if TYPE_CHECKING: | |||
from synapse.types.state import StateFilter | |||
class UnpersistedEventContextBase(ABC): | |||
""" | |||
This is a base class for EventContext and UnpersistedEventContext, objects which | |||
hold information relevant to storing an associated event. Note that an | |||
UnpersistedEventContexts must be converted into an EventContext before it is | |||
suitable to send to the db with its associated event. | |||
Attributes: | |||
_storage: storage controllers for interfacing with the database | |||
app_service: If the associated event is being sent by a (local) application service, that | |||
app service. | |||
""" | |||
def __init__(self, storage_controller: "StorageControllers"): | |||
self._storage: "StorageControllers" = storage_controller | |||
self.app_service: Optional[ApplicationService] = None | |||
@abstractmethod | |||
async def persist( | |||
self, | |||
event: EventBase, | |||
) -> "EventContext": | |||
""" | |||
A method to convert an UnpersistedEventContext to an EventContext, suitable for | |||
sending to the database with the associated event. | |||
""" | |||
pass | |||
@abstractmethod | |||
async def get_prev_state_ids( | |||
self, state_filter: Optional["StateFilter"] = None | |||
) -> StateMap[str]: | |||
""" | |||
Gets the room state at the event (ie not including the event if the event is a | |||
state event). | |||
Args: | |||
state_filter: specifies the type of state event to fetch from DB, example: | |||
EventTypes.JoinRules | |||
""" | |||
pass | |||
@attr.s(slots=True, auto_attribs=True) | |||
class EventContext: | |||
class EventContext(UnpersistedEventContextBase): | |||
""" | |||
Holds information relevant to persisting an event | |||
@@ -77,9 +121,6 @@ class EventContext: | |||
delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` | |||
and ``state_group``. | |||
app_service: If this event is being sent by a (local) application service, that | |||
app service. | |||
partial_state: if True, we may be storing this event with a temporary, | |||
incomplete state. | |||
""" | |||
@@ -122,6 +163,9 @@ class EventContext: | |||
"""Return an EventContext instance suitable for persisting an outlier event""" | |||
return EventContext(storage=storage) | |||
async def persist(self, event: EventBase) -> "EventContext": | |||
return self | |||
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: | |||
"""Converts self to a type that can be serialized as JSON, and then | |||
deserialized by `deserialize` | |||
@@ -254,6 +298,128 @@ class EventContext: | |||
) | |||
@attr.s(slots=True, auto_attribs=True) | |||
class UnpersistedEventContext(UnpersistedEventContextBase): | |||
""" | |||
The event context holds information about the state groups for an event. It is important | |||
to remember that an event technically has two state groups: the state group before the | |||
event, and the state group after the event. If the event is not a state event, the state | |||
group will not change (ie the state group before the event will be the same as the state | |||
group after the event), but if it is a state event the state group before the event | |||
will differ from the state group after the event. | |||
This is a version of an EventContext before the new state group (if any) has been | |||
computed and stored. It contains information about the state before the event (which | |||
also may be the information after the event, if the event is not a state event). The | |||
UnpersistedEventContext must be converted into an EventContext by calling the method | |||
'persist' on it before it is suitable to be sent to the DB for processing. | |||
state_group_after_event: | |||
The state group after the event. This will always be None until it is persisted. | |||
If the event is not a state event, this will be the same as | |||
state_group_before_event. | |||
state_group_before_event: | |||
The ID of the state group representing the state of the room before this event. | |||
state_delta_due_to_event: | |||
If the event is a state event, then this is the delta of the state between | |||
`state_group` and `state_group_before_event` | |||
prev_group_for_state_group_before_event: | |||
If it is known, ``state_group_before_event``'s previous state group. | |||
delta_ids_to_state_group_before_event: | |||
If ``prev_group_for_state_group_before_event`` is not None, the state delta | |||
between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``. | |||
partial_state: | |||
Whether the event has partial state. | |||
state_map_before_event: | |||
A map of the state before the event, i.e. the state at `state_group_before_event` | |||
""" | |||
_storage: "StorageControllers" | |||
state_group_before_event: Optional[int] | |||
state_group_after_event: Optional[int] | |||
state_delta_due_to_event: Optional[dict] | |||
prev_group_for_state_group_before_event: Optional[int] | |||
delta_ids_to_state_group_before_event: Optional[StateMap[str]] | |||
partial_state: bool | |||
state_map_before_event: Optional[StateMap[str]] = None | |||
async def get_prev_state_ids( | |||
self, state_filter: Optional["StateFilter"] = None | |||
) -> StateMap[str]: | |||
""" | |||
Gets the room state map, excluding this event. | |||
Args: | |||
state_filter: specifies the type of state event to fetch from DB | |||
Returns: | |||
Maps a (type, state_key) to the event ID of the state event matching | |||
this tuple. | |||
""" | |||
if self.state_map_before_event: | |||
return self.state_map_before_event | |||
assert self.state_group_before_event is not None | |||
return await self._storage.state.get_state_ids_for_group( | |||
self.state_group_before_event, state_filter | |||
) | |||
async def persist(self, event: EventBase) -> EventContext: | |||
""" | |||
Creates a full `EventContext` for the event, persisting any referenced state that | |||
has not yet been persisted. | |||
Args: | |||
event: event that the EventContext is associated with. | |||
Returns: An EventContext suitable for sending to the database with the event | |||
for persisting | |||
""" | |||
assert self.partial_state is not None | |||
# If we have a full set of state for before the event but don't have a state | |||
# group for that state, we need to get one | |||
if self.state_group_before_event is None: | |||
assert self.state_map_before_event | |||
state_group_before_event = await self._storage.state.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=self.prev_group_for_state_group_before_event, | |||
delta_ids=self.delta_ids_to_state_group_before_event, | |||
current_state_ids=self.state_map_before_event, | |||
) | |||
self.state_group_before_event = state_group_before_event | |||
# if the event isn't a state event the state group doesn't change | |||
if not self.state_delta_due_to_event: | |||
state_group_after_event = self.state_group_before_event | |||
# otherwise if it is a state event we need to get a state group for it | |||
else: | |||
state_group_after_event = await self._storage.state.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=self.state_group_before_event, | |||
delta_ids=self.state_delta_due_to_event, | |||
current_state_ids=None, | |||
) | |||
return EventContext.with_state( | |||
storage=self._storage, | |||
state_group=state_group_after_event, | |||
state_group_before_event=self.state_group_before_event, | |||
state_delta_due_to_event=self.state_delta_due_to_event, | |||
partial_state=self.partial_state, | |||
prev_group=self.state_group_before_event, | |||
delta_ids=self.state_delta_due_to_event, | |||
) | |||
def _encode_state_dict( | |||
state_dict: Optional[StateMap[str]], | |||
) -> Optional[List[Tuple[str, str, str]]]: | |||
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError | |||
from synapse.api.errors import ModuleFailedException, SynapseError | |||
from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.events.snapshot import UnpersistedEventContextBase | |||
from synapse.storage.roommember import ProfileInfo | |||
from synapse.types import Requester, StateMap | |||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable | |||
@@ -231,7 +231,9 @@ class ThirdPartyEventRules: | |||
self._on_threepid_bind_callbacks.append(on_threepid_bind) | |||
async def check_event_allowed( | |||
self, event: EventBase, context: EventContext | |||
self, | |||
event: EventBase, | |||
context: UnpersistedEventContextBase, | |||
) -> Tuple[bool, Optional[dict]]: | |||
"""Check if a provided event should be allowed in the given context. | |||
@@ -56,7 +56,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion | |||
from synapse.crypto.event_signing import compute_event_signature | |||
from synapse.event_auth import validate_event_for_room_version | |||
from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase | |||
from synapse.events.validator import EventValidator | |||
from synapse.federation.federation_client import InvalidResponseError | |||
from synapse.http.servlet import assert_params_in_dict | |||
@@ -990,7 +990,10 @@ class FederationHandler: | |||
) | |||
try: | |||
event, context = await self.event_creation_handler.create_new_client_event( | |||
( | |||
event, | |||
unpersisted_context, | |||
) = await self.event_creation_handler.create_new_client_event( | |||
builder=builder | |||
) | |||
except SynapseError as e: | |||
@@ -998,7 +1001,9 @@ class FederationHandler: | |||
raise | |||
# Ensure the user can even join the room. | |||
await self._federation_event_handler.check_join_restrictions(context, event) | |||
await self._federation_event_handler.check_join_restrictions( | |||
unpersisted_context, event | |||
) | |||
# The remote hasn't signed it yet, obviously. We'll do the full checks | |||
# when we get the event back in `on_send_join_request` | |||
@@ -1178,7 +1183,7 @@ class FederationHandler: | |||
}, | |||
) | |||
event, context = await self.event_creation_handler.create_new_client_event( | |||
event, _ = await self.event_creation_handler.create_new_client_event( | |||
builder=builder | |||
) | |||
@@ -1228,12 +1233,13 @@ class FederationHandler: | |||
}, | |||
) | |||
event, context = await self.event_creation_handler.create_new_client_event( | |||
builder=builder | |||
) | |||
( | |||
event, | |||
unpersisted_context, | |||
) = await self.event_creation_handler.create_new_client_event(builder=builder) | |||
event_allowed, _ = await self.third_party_event_rules.check_event_allowed( | |||
event, context | |||
event, unpersisted_context | |||
) | |||
if not event_allowed: | |||
logger.warning("Creation of knock %s forbidden by third-party rules", event) | |||
@@ -1406,15 +1412,20 @@ class FederationHandler: | |||
try: | |||
( | |||
event, | |||
context, | |||
unpersisted_context, | |||
) = await self.event_creation_handler.create_new_client_event( | |||
builder=builder | |||
) | |||
event, context = await self.add_display_name_to_third_party_invite( | |||
room_version_obj, event_dict, event, context | |||
( | |||
event, | |||
unpersisted_context, | |||
) = await self.add_display_name_to_third_party_invite( | |||
room_version_obj, event_dict, event, unpersisted_context | |||
) | |||
context = await unpersisted_context.persist(event) | |||
EventValidator().validate_new(event, self.config) | |||
# We need to tell the transaction queue to send this out, even | |||
@@ -1483,14 +1494,19 @@ class FederationHandler: | |||
try: | |||
( | |||
event, | |||
context, | |||
unpersisted_context, | |||
) = await self.event_creation_handler.create_new_client_event( | |||
builder=builder | |||
) | |||
event, context = await self.add_display_name_to_third_party_invite( | |||
room_version_obj, event_dict, event, context | |||
( | |||
event, | |||
unpersisted_context, | |||
) = await self.add_display_name_to_third_party_invite( | |||
room_version_obj, event_dict, event, unpersisted_context | |||
) | |||
context = await unpersisted_context.persist(event) | |||
try: | |||
validate_event_for_room_version(event) | |||
await self._event_auth_handler.check_auth_rules_from_context(event) | |||
@@ -1522,8 +1538,8 @@ class FederationHandler: | |||
room_version_obj: RoomVersion, | |||
event_dict: JsonDict, | |||
event: EventBase, | |||
context: EventContext, | |||
) -> Tuple[EventBase, EventContext]: | |||
context: UnpersistedEventContextBase, | |||
) -> Tuple[EventBase, UnpersistedEventContextBase]: | |||
key = ( | |||
EventTypes.ThirdPartyInvite, | |||
event.content["third_party_invite"]["signed"]["token"], | |||
@@ -1557,11 +1573,14 @@ class FederationHandler: | |||
room_version_obj, event_dict | |||
) | |||
EventValidator().validate_builder(builder) | |||
event, context = await self.event_creation_handler.create_new_client_event( | |||
builder=builder | |||
) | |||
( | |||
event, | |||
unpersisted_context, | |||
) = await self.event_creation_handler.create_new_client_event(builder=builder) | |||
EventValidator().validate_new(event, self.config) | |||
return event, context | |||
return event, unpersisted_context | |||
async def _check_signature(self, event: EventBase, context: EventContext) -> None: | |||
""" | |||
@@ -58,7 +58,7 @@ from synapse.event_auth import ( | |||
validate_event_for_room_version, | |||
) | |||
from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase | |||
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo | |||
from synapse.logging.context import nested_logging_context | |||
from synapse.logging.opentracing import ( | |||
@@ -426,7 +426,9 @@ class FederationEventHandler: | |||
return event, context | |||
async def check_join_restrictions( | |||
self, context: EventContext, event: EventBase | |||
self, | |||
context: UnpersistedEventContextBase, | |||
event: EventBase, | |||
) -> None: | |||
"""Check that restrictions in restricted join rules are matched | |||
@@ -48,7 +48,7 @@ from synapse.api.urls import ConsentURIBuilder | |||
from synapse.event_auth import validate_event_for_room_version | |||
from synapse.events import EventBase, relation_from_event | |||
from synapse.events.builder import EventBuilder | |||
from synapse.events.snapshot import EventContext | |||
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase | |||
from synapse.events.utils import maybe_upsert_event_field | |||
from synapse.events.validator import EventValidator | |||
from synapse.handlers.directory import DirectoryHandler | |||
@@ -708,7 +708,7 @@ class EventCreationHandler: | |||
builder.internal_metadata.historical = historical | |||
event, context = await self.create_new_client_event( | |||
event, unpersisted_context = await self.create_new_client_event( | |||
builder=builder, | |||
requester=requester, | |||
allow_no_prev_events=allow_no_prev_events, | |||
@@ -721,6 +721,8 @@ class EventCreationHandler: | |||
current_state_group=current_state_group, | |||
) | |||
context = await unpersisted_context.persist(event) | |||
# In an ideal world we wouldn't need the second part of this condition. However, | |||
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this | |||
# behaviour. Another reason is that this code is also evaluated each time a new | |||
@@ -1083,13 +1085,14 @@ class EventCreationHandler: | |||
state_map: Optional[StateMap[str]] = None, | |||
for_batch: bool = False, | |||
current_state_group: Optional[int] = None, | |||
) -> Tuple[EventBase, EventContext]: | |||
) -> Tuple[EventBase, UnpersistedEventContextBase]: | |||
"""Create a new event for a local client. If bool for_batch is true, will | |||
create an event using the prev_event_ids, and will create an event context for | |||
the event using the parameters state_map and current_state_group, thus these parameters | |||
must be provided in this case if for_batch is True. The subsequently created event | |||
and context are suitable for being batched up and bulk persisted to the database | |||
with other similarly created events. | |||
with other similarly created events. Note that this returns an UnpersistedEventContext, | |||
which must be converted to an EventContext before it can be sent to the DB. | |||
Args: | |||
builder: | |||
@@ -1131,7 +1134,7 @@ class EventCreationHandler: | |||
batch persisting | |||
Returns: | |||
Tuple of created event, context | |||
Tuple of created event, UnpersistedEventContext | |||
""" | |||
# Strip down the state_event_ids to only what we need to auth the event. | |||
# For example, we don't need extra m.room.member that don't match event.sender | |||
@@ -1192,9 +1195,16 @@ class EventCreationHandler: | |||
event = await builder.build( | |||
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth | |||
) | |||
context = await self.state.compute_event_context_for_batched( | |||
event, state_map, current_state_group | |||
context: UnpersistedEventContextBase = ( | |||
await self.state.calculate_context_info( | |||
event, | |||
state_ids_before_event=state_map, | |||
partial_state=False, | |||
state_group_before_event=current_state_group, | |||
) | |||
) | |||
else: | |||
event = await builder.build( | |||
prev_event_ids=prev_event_ids, | |||
@@ -1244,16 +1254,17 @@ class EventCreationHandler: | |||
state_map_for_event[(data.event_type, data.state_key)] = state_id | |||
context = await self.state.compute_event_context( | |||
# TODO(faster_joins): check how MSC2716 works and whether we can have | |||
# partial state here | |||
# https://github.com/matrix-org/synapse/issues/13003 | |||
context = await self.state.calculate_context_info( | |||
event, | |||
state_ids_before_event=state_map_for_event, | |||
# TODO(faster_joins): check how MSC2716 works and whether we can have | |||
# partial state here | |||
# https://github.com/matrix-org/synapse/issues/13003 | |||
partial_state=False, | |||
) | |||
else: | |||
context = await self.state.compute_event_context(event) | |||
context = await self.state.calculate_context_info(event) | |||
if requester: | |||
context.app_service = requester.app_service | |||
@@ -2082,9 +2093,9 @@ class EventCreationHandler: | |||
async def _rebuild_event_after_third_party_rules( | |||
self, third_party_result: dict, original_event: EventBase | |||
) -> Tuple[EventBase, EventContext]: | |||
) -> Tuple[EventBase, UnpersistedEventContextBase]: | |||
# the third_party_event_rules want to replace the event. | |||
# we do some basic checks, and then return the replacement event and context. | |||
# we do some basic checks, and then return the replacement event. | |||
# Construct a new EventBuilder and validate it, which helps with the | |||
# rest of these checks. | |||
@@ -2138,5 +2149,6 @@ class EventCreationHandler: | |||
# we rebuild the event context, to be on the safe side. If nothing else, | |||
# delta_ids might need an update. | |||
context = await self.state.compute_event_context(event) | |||
context = await self.state.calculate_context_info(event) | |||
return event, context |
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram | |||
from synapse.api.constants import EventTypes | |||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions | |||
from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.events.snapshot import ( | |||
EventContext, | |||
UnpersistedEventContext, | |||
UnpersistedEventContextBase, | |||
) | |||
from synapse.logging.context import ContextResourceUsage | |||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet | |||
from synapse.state import v1, v2 | |||
@@ -262,31 +266,31 @@ class StateHandler: | |||
state = await entry.get_state(self._state_storage_controller, StateFilter.all()) | |||
return await self.store.get_joined_hosts(room_id, state, entry) | |||
async def compute_event_context( | |||
async def calculate_context_info( | |||
self, | |||
event: EventBase, | |||
state_ids_before_event: Optional[StateMap[str]] = None, | |||
partial_state: Optional[bool] = None, | |||
) -> EventContext: | |||
"""Build an EventContext structure for a non-outlier event. | |||
(for an outlier, call EventContext.for_outlier directly) | |||
This works out what the current state should be for the event, and | |||
generates a new state group if necessary. | |||
Args: | |||
event: | |||
state_ids_before_event: The event ids of the state before the event if | |||
it can't be calculated from existing events. This is normally | |||
only specified when receiving an event from federation where we | |||
don't have the prev events, e.g. when backfilling. | |||
partial_state: | |||
`True` if `state_ids_before_event` is partial and omits non-critical | |||
membership events. | |||
`False` if `state_ids_before_event` is the full state. | |||
`None` when `state_ids_before_event` is not provided. In this case, the | |||
flag will be calculated based on `event`'s prev events. | |||
state_group_before_event: Optional[int] = None, | |||
) -> UnpersistedEventContextBase: | |||
""" | |||
Calulates the contents of an unpersisted event context, other than the current | |||
state group (which is either provided or calculated when the event context is persisted) | |||
state_ids_before_event: | |||
The event ids of the full state before the event if | |||
it can't be calculated from existing events. This is normally | |||
only specified when receiving an event from federation where we | |||
don't have the prev events, e.g. when backfilling or when the event | |||
is being created for batch persisting. | |||
partial_state: | |||
`True` if `state_ids_before_event` is partial and omits non-critical | |||
membership events. | |||
`False` if `state_ids_before_event` is the full state. | |||
`None` when `state_ids_before_event` is not provided. In this case, the | |||
flag will be calculated based on `event`'s prev events. | |||
state_group_before_event: | |||
the current state group at the time of event, if known | |||
Returns: | |||
The event context. | |||
@@ -294,7 +298,6 @@ class StateHandler: | |||
RuntimeError if `state_ids_before_event` is not provided and one or more | |||
prev events are missing or outliers. | |||
""" | |||
assert not event.internal_metadata.is_outlier() | |||
# | |||
@@ -306,17 +309,6 @@ class StateHandler: | |||
state_group_before_event_prev_group = None | |||
deltas_to_state_group_before_event = None | |||
# .. though we need to get a state group for it. | |||
state_group_before_event = ( | |||
await self._state_storage_controller.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=None, | |||
delta_ids=None, | |||
current_state_ids=state_ids_before_event, | |||
) | |||
) | |||
# the partial_state flag must be provided | |||
assert partial_state is not None | |||
else: | |||
@@ -345,6 +337,7 @@ class StateHandler: | |||
logger.debug("calling resolve_state_groups from compute_event_context") | |||
# we've already taken into account partial state, so no need to wait for | |||
# complete state here. | |||
entry = await self.resolve_state_groups_for_events( | |||
event.room_id, | |||
event.prev_event_ids(), | |||
@@ -383,18 +376,19 @@ class StateHandler: | |||
# | |||
if not event.is_state(): | |||
return EventContext.with_state( | |||
return UnpersistedEventContext( | |||
storage=self._storage_controllers, | |||
state_group_before_event=state_group_before_event, | |||
state_group=state_group_before_event, | |||
state_group_after_event=state_group_before_event, | |||
state_delta_due_to_event={}, | |||
prev_group=state_group_before_event_prev_group, | |||
delta_ids=deltas_to_state_group_before_event, | |||
prev_group_for_state_group_before_event=state_group_before_event_prev_group, | |||
delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, | |||
partial_state=partial_state, | |||
state_map_before_event=state_ids_before_event, | |||
) | |||
# | |||
# otherwise, we'll need to create a new state group for after the event | |||
# otherwise, we'll need to set up creating a new state group for after the event | |||
# | |||
key = (event.type, event.state_key) | |||
@@ -412,88 +406,60 @@ class StateHandler: | |||
delta_ids = {key: event.event_id} | |||
state_group_after_event = ( | |||
await self._state_storage_controller.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=state_group_before_event, | |||
delta_ids=delta_ids, | |||
current_state_ids=None, | |||
) | |||
) | |||
return EventContext.with_state( | |||
return UnpersistedEventContext( | |||
storage=self._storage_controllers, | |||
state_group=state_group_after_event, | |||
state_group_before_event=state_group_before_event, | |||
state_group_after_event=None, | |||
state_delta_due_to_event=delta_ids, | |||
prev_group=state_group_before_event, | |||
delta_ids=delta_ids, | |||
prev_group_for_state_group_before_event=state_group_before_event_prev_group, | |||
delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, | |||
partial_state=partial_state, | |||
state_map_before_event=state_ids_before_event, | |||
) | |||
async def compute_event_context_for_batched( | |||
async def compute_event_context( | |||
self, | |||
event: EventBase, | |||
state_ids_before_event: StateMap[str], | |||
current_state_group: int, | |||
state_ids_before_event: Optional[StateMap[str]] = None, | |||
partial_state: Optional[bool] = None, | |||
) -> EventContext: | |||
""" | |||
Generate an event context for an event that has not yet been persisted to the | |||
database. Intended for use with events that are created to be persisted in a batch. | |||
Args: | |||
event: the event the context is being computed for | |||
state_ids_before_event: a state map consisting of the state ids of the events | |||
created prior to this event. | |||
current_state_group: the current state group before the event. | |||
""" | |||
state_group_before_event_prev_group = None | |||
deltas_to_state_group_before_event = None | |||
state_group_before_event = current_state_group | |||
# if the event is not state, we are set | |||
if not event.is_state(): | |||
return EventContext.with_state( | |||
storage=self._storage_controllers, | |||
state_group_before_event=state_group_before_event, | |||
state_group=state_group_before_event, | |||
state_delta_due_to_event={}, | |||
prev_group=state_group_before_event_prev_group, | |||
delta_ids=deltas_to_state_group_before_event, | |||
partial_state=False, | |||
) | |||
"""Build an EventContext structure for a non-outlier event. | |||
# otherwise, we'll need to create a new state group for after the event | |||
key = (event.type, event.state_key) | |||
(for an outlier, call EventContext.for_outlier directly) | |||
if state_ids_before_event is not None: | |||
replaces = state_ids_before_event.get(key) | |||
This works out what the current state should be for the event, and | |||
generates a new state group if necessary. | |||
if replaces and replaces != event.event_id: | |||
event.unsigned["replaces_state"] = replaces | |||
Args: | |||
event: | |||
state_ids_before_event: The event ids of the state before the event if | |||
it can't be calculated from existing events. This is normally | |||
only specified when receiving an event from federation where we | |||
don't have the prev events, e.g. when backfilling. | |||
partial_state: | |||
`True` if `state_ids_before_event` is partial and omits non-critical | |||
membership events. | |||
`False` if `state_ids_before_event` is the full state. | |||
`None` when `state_ids_before_event` is not provided. In this case, the | |||
flag will be calculated based on `event`'s prev events. | |||
entry: | |||
A state cache entry for the resolved state across the prev events. We may | |||
have already calculated this, so if it's available pass it in | |||
Returns: | |||
The event context. | |||
delta_ids = {key: event.event_id} | |||
Raises: | |||
RuntimeError if `state_ids_before_event` is not provided and one or more | |||
prev events are missing or outliers. | |||
""" | |||
state_group_after_event = ( | |||
await self._state_storage_controller.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=state_group_before_event, | |||
delta_ids=delta_ids, | |||
current_state_ids=None, | |||
) | |||
unpersisted_context = await self.calculate_context_info( | |||
event=event, | |||
state_ids_before_event=state_ids_before_event, | |||
partial_state=partial_state, | |||
) | |||
return EventContext.with_state( | |||
storage=self._storage_controllers, | |||
state_group=state_group_after_event, | |||
state_group_before_event=state_group_before_event, | |||
state_delta_due_to_event=delta_ids, | |||
prev_group=state_group_before_event, | |||
delta_ids=delta_ids, | |||
partial_state=False, | |||
) | |||
return await unpersisted_context.persist(event) | |||
@measure_func() | |||
async def resolve_state_groups_for_events( | |||
@@ -949,10 +949,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success( | |||
self.hs.get_storage_controllers().persistence.persist_event(event, context) | |||
) | |||
@@ -2934,10 +2934,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success(storage_controllers.persistence.persist_event(event, context)) | |||
# Now get rooms | |||
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
return event | |||
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
return event | |||
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success(self._persistence.persist_event(event, context)) | |||
return event | |||
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
def internal_metadata(self) -> _EventInternalMetadata: | |||
return self._base_builder.internal_metadata | |||
event_1, context_1 = self.get_success( | |||
event_1, unpersisted_context_1 = self.get_success( | |||
self.event_creation_handler.create_new_client_event( | |||
cast( | |||
EventBuilder, | |||
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
context_1 = self.get_success(unpersisted_context_1.persist(event_1)) | |||
self.get_success(self._persistence.persist_event(event_1, context_1)) | |||
event_2, context_2 = self.get_success( | |||
event_2, unpersisted_context_2 = self.get_success( | |||
self.event_creation_handler.create_new_client_event( | |||
cast( | |||
EventBuilder, | |||
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
) | |||
context_2 = self.get_success(unpersisted_context_2.persist(event_2)) | |||
self.get_success(self._persistence.persist_event(event_2, context_2)) | |||
# fetch one of the redactions | |||
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
redaction_event, context = self.get_success( | |||
redaction_event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(redaction_event)) | |||
self.get_success(self._persistence.persist_event(redaction_event, context)) | |||
# Now lets jump to the future where we have censored the redaction event | |||
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
assert self.storage.persistence is not None | |||
self.get_success(self.storage.persistence.persist_event(event, context)) | |||
@@ -92,8 +92,13 @@ async def create_event( | |||
builder = hs.get_event_builder_factory().for_room_version( | |||
KNOWN_ROOM_VERSIONS[room_version], kwargs | |||
) | |||
event, context = await hs.get_event_creation_handler().create_new_client_event( | |||
( | |||
event, | |||
unpersisted_context, | |||
) = await hs.get_event_creation_handler().create_new_client_event( | |||
builder, prev_event_ids=prev_event_ids | |||
) | |||
context = await unpersisted_context.persist(event) | |||
return event, context |
@@ -175,9 +175,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success( | |||
self._storage_controllers.persistence.persist_event(event, context) | |||
) | |||
@@ -202,9 +203,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success( | |||
self._storage_controllers.persistence.persist_event(event, context) | |||
@@ -226,9 +228,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
event, context = self.get_success( | |||
event, unpersisted_context = self.get_success( | |||
self.event_creation_handler.create_new_client_event(builder) | |||
) | |||
context = self.get_success(unpersisted_context.persist(event)) | |||
self.get_success( | |||
self._storage_controllers.persistence.persist_event(event, context) | |||
@@ -335,6 +335,9 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None: | |||
}, | |||
) | |||
event, context = await event_creation_handler.create_new_client_event(builder) | |||
event, unpersisted_context = await event_creation_handler.create_new_client_event( | |||
builder | |||
) | |||
context = await unpersisted_context.persist(event) | |||
await persistence_store.persist_event(event, context) |