* Split state group persist into seperate storage func * Add per database engine code for state group id gen * Move store_state_group to StateReadStore This allows other workers to use it, and so resolve state. * Hook up store_state_group * Fix tests * Rename _store_mult_state_groups_txn * Rename StateGroupReadStore * Remove redundant _have_persisted_state_group_txn * Update comments * Comment compute_event_context * Set start val for state_group_id_seq ... otherwise we try to recreate old state groups * Update comments * Don't store state for outliers * Update comment * Update docstring as state groups are intstags/v0.27.0-rc1
@@ -25,7 +25,9 @@ class EventContext(object): | |||
The current state map excluding the current event. | |||
(type, state_key) -> event_id | |||
state_group (int): state group id | |||
state_group (int|None): state group id, if the state has been stored | |||
as a state group. This is usually only None if e.g. the event is | |||
an outlier. | |||
rejected (bool|str): A rejection reason if the event was rejected, else | |||
False | |||
@@ -1831,8 +1831,8 @@ class FederationHandler(BaseHandler): | |||
current_state = set(e.event_id for e in auth_events.values()) | |||
different_auth = event_auth_events - current_state | |||
self._update_context_for_auth_events( | |||
context, auth_events, event_key, | |||
yield self._update_context_for_auth_events( | |||
event, context, auth_events, event_key, | |||
) | |||
if different_auth and not event.internal_metadata.is_outlier(): | |||
@@ -1913,8 +1913,8 @@ class FederationHandler(BaseHandler): | |||
# 4. Look at rejects and their proofs. | |||
# TODO. | |||
self._update_context_for_auth_events( | |||
context, auth_events, event_key, | |||
yield self._update_context_for_auth_events( | |||
event, context, auth_events, event_key, | |||
) | |||
try: | |||
@@ -1923,11 +1923,15 @@ class FederationHandler(BaseHandler): | |||
logger.warn("Failed auth resolution for %r because %s", event, e) | |||
raise e | |||
def _update_context_for_auth_events(self, context, auth_events, | |||
@defer.inlineCallbacks | |||
def _update_context_for_auth_events(self, event, context, auth_events, | |||
event_key): | |||
"""Update the state_ids in an event context after auth event resolution | |||
"""Update the state_ids in an event context after auth event resolution, | |||
storing the changes as a new state group. | |||
Args: | |||
event (Event): The event we're handling the context for | |||
context (synapse.events.snapshot.EventContext): event context | |||
to be updated | |||
@@ -1950,7 +1954,13 @@ class FederationHandler(BaseHandler): | |||
context.prev_state_ids.update({ | |||
k: a.event_id for k, a in auth_events.iteritems() | |||
}) | |||
context.state_group = self.store.get_next_state_group() | |||
context.state_group = yield self.store.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=context.prev_group, | |||
delta_ids=context.delta_ids, | |||
current_state_ids=context.current_state_ids, | |||
) | |||
@defer.inlineCallbacks | |||
def construct_auth_difference(self, local_auth, remote_auth): | |||
@@ -19,7 +19,7 @@ from synapse.storage import DataStore | |||
from synapse.storage.event_federation import EventFederationStore | |||
from synapse.storage.event_push_actions import EventPushActionsStore | |||
from synapse.storage.roommember import RoomMemberStore | |||
from synapse.storage.state import StateGroupReadStore | |||
from synapse.storage.state import StateGroupWorkerStore | |||
from synapse.storage.stream import StreamStore | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
from ._base import BaseSlavedStore | |||
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) | |||
# the method descriptor on the DataStore and chuck them into our class. | |||
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore): | |||
class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore): | |||
def __init__(self, db_conn, hs): | |||
super(SlavedEventStore, self).__init__(db_conn, hs) | |||
@@ -183,8 +183,15 @@ class StateHandler(object): | |||
def compute_event_context(self, event, old_state=None): | |||
"""Build an EventContext structure for the event. | |||
This works out what the current state should be for the event, and | |||
generates a new state group if necessary. | |||
Args: | |||
event (synapse.events.EventBase): | |||
old_state (dict|None): The state at the event if it can't be | |||
calculated from existing events. This is normally only specified | |||
when receiving an event from federation where we don't have the | |||
prev events for, e.g. when backfilling. | |||
Returns: | |||
synapse.events.snapshot.EventContext: | |||
""" | |||
@@ -208,15 +215,22 @@ class StateHandler(object): | |||
context.current_state_ids = {} | |||
context.prev_state_ids = {} | |||
context.prev_state_events = [] | |||
context.state_group = self.store.get_next_state_group() | |||
# We don't store state for outliers, so we don't generate a state | |||
# froup for it. | |||
context.state_group = None | |||
defer.returnValue(context) | |||
if old_state: | |||
# We already have the state, so we don't need to calculate it. | |||
# Let's just correctly fill out the context and create a | |||
# new state group for it. | |||
context = EventContext() | |||
context.prev_state_ids = { | |||
(s.type, s.state_key): s.event_id for s in old_state | |||
} | |||
context.state_group = self.store.get_next_state_group() | |||
if event.is_state(): | |||
key = (event.type, event.state_key) | |||
@@ -229,6 +243,14 @@ class StateHandler(object): | |||
else: | |||
context.current_state_ids = context.prev_state_ids | |||
context.state_group = yield self.store.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=None, | |||
delta_ids=None, | |||
current_state_ids=context.current_state_ids, | |||
) | |||
context.prev_state_events = [] | |||
defer.returnValue(context) | |||
@@ -242,7 +264,8 @@ class StateHandler(object): | |||
context = EventContext() | |||
context.prev_state_ids = curr_state | |||
if event.is_state(): | |||
context.state_group = self.store.get_next_state_group() | |||
# If this is a state event then we need to create a new state | |||
# group for the state after this event. | |||
key = (event.type, event.state_key) | |||
if key in context.prev_state_ids: | |||
@@ -253,23 +276,42 @@ class StateHandler(object): | |||
context.current_state_ids[key] = event.event_id | |||
if entry.state_group: | |||
# If the state at the event has a state group assigned then | |||
# we can use that as the prev group | |||
context.prev_group = entry.state_group | |||
context.delta_ids = { | |||
key: event.event_id | |||
} | |||
elif entry.prev_group: | |||
# If the state at the event only has a prev group, then we can | |||
# use that as a prev group too. | |||
context.prev_group = entry.prev_group | |||
context.delta_ids = dict(entry.delta_ids) | |||
context.delta_ids[key] = event.event_id | |||
context.state_group = yield self.store.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=context.prev_group, | |||
delta_ids=context.delta_ids, | |||
current_state_ids=context.current_state_ids, | |||
) | |||
else: | |||
context.current_state_ids = context.prev_state_ids | |||
context.prev_group = entry.prev_group | |||
context.delta_ids = entry.delta_ids | |||
if entry.state_group is None: | |||
entry.state_group = self.store.get_next_state_group() | |||
entry.state_group = yield self.store.store_state_group( | |||
event.event_id, | |||
event.room_id, | |||
prev_group=entry.prev_group, | |||
delta_ids=entry.delta_ids, | |||
current_state_ids=context.current_state_ids, | |||
) | |||
entry.state_id = entry.state_group | |||
context.state_group = entry.state_group | |||
context.current_state_ids = context.prev_state_ids | |||
context.prev_group = entry.prev_group | |||
context.delta_ids = entry.delta_ids | |||
context.prev_state_events = [] | |||
defer.returnValue(context) | |||
@@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore, | |||
) | |||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") | |||
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") | |||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") | |||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") | |||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") | |||
@@ -62,3 +62,9 @@ class PostgresEngine(object): | |||
def lock_table(self, txn, table): | |||
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) | |||
def get_next_state_group_id(self, txn): | |||
"""Returns an int that can be used as a new state_group ID | |||
""" | |||
txn.execute("SELECT nextval('state_group_id_seq')") | |||
return txn.fetchone()[0] |
@@ -16,6 +16,7 @@ | |||
from synapse.storage.prepare_database import prepare_database | |||
import struct | |||
import threading | |||
class Sqlite3Engine(object): | |||
@@ -24,6 +25,11 @@ class Sqlite3Engine(object): | |||
def __init__(self, database_module, database_config): | |||
self.module = database_module | |||
# The current max state_group, or None if we haven't looked | |||
# in the DB yet. | |||
self._current_state_group_id = None | |||
self._current_state_group_id_lock = threading.Lock() | |||
def check_database(self, txn): | |||
pass | |||
@@ -43,6 +49,19 @@ class Sqlite3Engine(object): | |||
def lock_table(self, txn, table): | |||
return | |||
def get_next_state_group_id(self, txn): | |||
"""Returns an int that can be used as a new state_group ID | |||
""" | |||
# We do application locking here since if we're using sqlite then | |||
# we are a single process synapse. | |||
with self._current_state_group_id_lock: | |||
if self._current_state_group_id is None: | |||
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") | |||
self._current_state_group_id = txn.fetchone()[0] | |||
self._current_state_group_id += 1 | |||
return self._current_state_group_id | |||
# Following functions taken from: https://github.com/coleifer/peewee | |||
@@ -755,9 +755,8 @@ class EventsStore(SQLBaseStore): | |||
events_and_contexts=events_and_contexts, | |||
) | |||
# Insert into the state_groups, state_groups_state, and | |||
# event_to_state_groups tables. | |||
self._store_mult_state_groups_txn(txn, events_and_contexts) | |||
# Insert into event_to_state_groups. | |||
self._store_event_state_mappings_txn(txn, events_and_contexts) | |||
# _store_rejected_events_txn filters out any events which were | |||
# rejected, and returns the filtered list. | |||
@@ -992,10 +991,9 @@ class EventsStore(SQLBaseStore): | |||
# an outlier in the database. We now have some state at that | |||
# so we need to update the state_groups table with that state. | |||
# insert into the state_group, state_groups_state and | |||
# event_to_state_groups tables. | |||
# insert into event_to_state_groups. | |||
try: | |||
self._store_mult_state_groups_txn(txn, ((event, context),)) | |||
self._store_event_state_mappings_txn(txn, ((event, context),)) | |||
except Exception: | |||
logger.exception("") | |||
raise | |||
@@ -0,0 +1,37 @@ | |||
# Copyright 2018 New Vector Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.storage.engines import PostgresEngine | |||
def run_create(cur, database_engine, *args, **kwargs): | |||
if isinstance(database_engine, PostgresEngine): | |||
# if we already have some state groups, we want to start making new | |||
# ones with a higher id. | |||
cur.execute("SELECT max(id) FROM state_groups") | |||
row = cur.fetchone() | |||
if row[0] is None: | |||
start_val = 1 | |||
else: | |||
start_val = row[0] + 1 | |||
cur.execute( | |||
"CREATE SEQUENCE state_group_id_seq START WITH %s", | |||
(start_val, ), | |||
) | |||
def run_upgrade(*args, **kwargs): | |||
pass |
@@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt | |||
return len(self.delta_ids) if self.delta_ids else 0 | |||
class StateGroupReadStore(SQLBaseStore): | |||
"""The read-only parts of StateGroupStore | |||
None of these functions write to the state tables, so are suitable for | |||
including in the SlavedStores. | |||
class StateGroupWorkerStore(SQLBaseStore): | |||
"""The parts of StateGroupStore that can be called from workers. | |||
""" | |||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" | |||
@@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore): | |||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" | |||
def __init__(self, db_conn, hs): | |||
super(StateGroupReadStore, self).__init__(db_conn, hs) | |||
super(StateGroupWorkerStore, self).__init__(db_conn, hs) | |||
self._state_group_cache = DictionaryCache( | |||
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR | |||
@@ -549,116 +546,66 @@ class StateGroupReadStore(SQLBaseStore): | |||
defer.returnValue(results) | |||
def store_state_group(self, event_id, room_id, prev_group, delta_ids, | |||
current_state_ids): | |||
"""Store a new set of state, returning a newly assigned state group. | |||
class StateStore(StateGroupReadStore, BackgroundUpdateStore): | |||
""" Keeps track of the state at a given event. | |||
This is done by the concept of `state groups`. Every event is a assigned | |||
a state group (identified by an arbitrary string), which references a | |||
collection of state events. The current state of an event is then the | |||
collection of state events referenced by the event's state group. | |||
Hence, every change in the current state causes a new state group to be | |||
generated. However, if no change happens (e.g., if we get a message event | |||
with only one parent it inherits the state group from its parent.) | |||
There are three tables: | |||
* `state_groups`: Stores group name, first event with in the group and | |||
room id. | |||
* `event_to_state_groups`: Maps events to state groups. | |||
* `state_groups_state`: Maps state group to state events. | |||
""" | |||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" | |||
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" | |||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" | |||
def __init__(self, db_conn, hs): | |||
super(StateStore, self).__init__(db_conn, hs) | |||
self.register_background_update_handler( | |||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, | |||
self._background_deduplicate_state, | |||
) | |||
self.register_background_update_handler( | |||
self.STATE_GROUP_INDEX_UPDATE_NAME, | |||
self._background_index_state, | |||
) | |||
self.register_background_index_update( | |||
self.CURRENT_STATE_INDEX_UPDATE_NAME, | |||
index_name="current_state_events_member_index", | |||
table="current_state_events", | |||
columns=["state_key"], | |||
where_clause="type='m.room.member'", | |||
) | |||
def _have_persisted_state_group_txn(self, txn, state_group): | |||
txn.execute( | |||
"SELECT count(*) FROM state_groups WHERE id = ?", | |||
(state_group,) | |||
) | |||
row = txn.fetchone() | |||
return row and row[0] | |||
def _store_mult_state_groups_txn(self, txn, events_and_contexts): | |||
state_groups = {} | |||
for event, context in events_and_contexts: | |||
if event.internal_metadata.is_outlier(): | |||
continue | |||
Args: | |||
event_id (str): The event ID for which the state was calculated | |||
room_id (str) | |||
prev_group (int|None): A previous state group for the room, optional. | |||
delta_ids (dict|None): The delta between state at `prev_group` and | |||
`current_state_ids`, if `prev_group` was given. Same format as | |||
`current_state_ids`. | |||
current_state_ids (dict): The state to store. Map of (type, state_key) | |||
to event_id. | |||
if context.current_state_ids is None: | |||
Returns: | |||
Deferred[int]: The state group ID | |||
""" | |||
def _store_state_group_txn(txn): | |||
if current_state_ids is None: | |||
# AFAIK, this can never happen | |||
logger.error( | |||
"Non-outlier event %s had current_state_ids==None", | |||
event.event_id) | |||
continue | |||
raise Exception("current_state_ids cannot be None") | |||
# if the event was rejected, just give it the same state as its | |||
# predecessor. | |||
if context.rejected: | |||
state_groups[event.event_id] = context.prev_group | |||
continue | |||
state_groups[event.event_id] = context.state_group | |||
if self._have_persisted_state_group_txn(txn, context.state_group): | |||
continue | |||
state_group = self.database_engine.get_next_state_group_id(txn) | |||
self._simple_insert_txn( | |||
txn, | |||
table="state_groups", | |||
values={ | |||
"id": context.state_group, | |||
"room_id": event.room_id, | |||
"event_id": event.event_id, | |||
"id": state_group, | |||
"room_id": room_id, | |||
"event_id": event_id, | |||
}, | |||
) | |||
# We persist as a delta if we can, while also ensuring the chain | |||
# of deltas isn't tooo long, as otherwise read performance degrades. | |||
if context.prev_group: | |||
if prev_group: | |||
is_in_db = self._simple_select_one_onecol_txn( | |||
txn, | |||
table="state_groups", | |||
keyvalues={"id": context.prev_group}, | |||
keyvalues={"id": prev_group}, | |||
retcol="id", | |||
allow_none=True, | |||
) | |||
if not is_in_db: | |||
raise Exception( | |||
"Trying to persist state with unpersisted prev_group: %r" | |||
% (context.prev_group,) | |||
% (prev_group,) | |||
) | |||
potential_hops = self._count_state_group_hops_txn( | |||
txn, context.prev_group | |||
txn, prev_group | |||
) | |||
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: | |||
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: | |||
self._simple_insert_txn( | |||
txn, | |||
table="state_group_edges", | |||
values={ | |||
"state_group": context.state_group, | |||
"prev_state_group": context.prev_group, | |||
"state_group": state_group, | |||
"prev_state_group": prev_group, | |||
}, | |||
) | |||
@@ -667,13 +614,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): | |||
table="state_groups_state", | |||
values=[ | |||
{ | |||
"state_group": context.state_group, | |||
"room_id": event.room_id, | |||
"state_group": state_group, | |||
"room_id": room_id, | |||
"type": key[0], | |||
"state_key": key[1], | |||
"event_id": state_id, | |||
} | |||
for key, state_id in context.delta_ids.iteritems() | |||
for key, state_id in delta_ids.iteritems() | |||
], | |||
) | |||
else: | |||
@@ -682,13 +629,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): | |||
table="state_groups_state", | |||
values=[ | |||
{ | |||
"state_group": context.state_group, | |||
"room_id": event.room_id, | |||
"state_group": state_group, | |||
"room_id": room_id, | |||
"type": key[0], | |||
"state_key": key[1], | |||
"event_id": state_id, | |||
} | |||
for key, state_id in context.current_state_ids.iteritems() | |||
for key, state_id in current_state_ids.iteritems() | |||
], | |||
) | |||
@@ -699,11 +646,71 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): | |||
txn.call_after( | |||
self._state_group_cache.update, | |||
self._state_group_cache.sequence, | |||
key=context.state_group, | |||
value=dict(context.current_state_ids), | |||
key=state_group, | |||
value=dict(current_state_ids), | |||
full=True, | |||
) | |||
return state_group | |||
return self.runInteraction("store_state_group", _store_state_group_txn) | |||
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): | |||
""" Keeps track of the state at a given event. | |||
This is done by the concept of `state groups`. Every event is a assigned | |||
a state group (identified by an arbitrary string), which references a | |||
collection of state events. The current state of an event is then the | |||
collection of state events referenced by the event's state group. | |||
Hence, every change in the current state causes a new state group to be | |||
generated. However, if no change happens (e.g., if we get a message event | |||
with only one parent it inherits the state group from its parent.) | |||
There are three tables: | |||
* `state_groups`: Stores group name, first event with in the group and | |||
room id. | |||
* `event_to_state_groups`: Maps events to state groups. | |||
* `state_groups_state`: Maps state group to state events. | |||
""" | |||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" | |||
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" | |||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" | |||
def __init__(self, db_conn, hs): | |||
super(StateStore, self).__init__(db_conn, hs) | |||
self.register_background_update_handler( | |||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, | |||
self._background_deduplicate_state, | |||
) | |||
self.register_background_update_handler( | |||
self.STATE_GROUP_INDEX_UPDATE_NAME, | |||
self._background_index_state, | |||
) | |||
self.register_background_index_update( | |||
self.CURRENT_STATE_INDEX_UPDATE_NAME, | |||
index_name="current_state_events_member_index", | |||
table="current_state_events", | |||
columns=["state_key"], | |||
where_clause="type='m.room.member'", | |||
) | |||
def _store_event_state_mappings_txn(self, txn, events_and_contexts): | |||
state_groups = {} | |||
for event, context in events_and_contexts: | |||
if event.internal_metadata.is_outlier(): | |||
continue | |||
# if the event was rejected, just give it the same state as its | |||
# predecessor. | |||
if context.rejected: | |||
state_groups[event.event_id] = context.prev_group | |||
continue | |||
state_groups[event.event_id] = context.state_group | |||
self._simple_insert_many_txn( | |||
txn, | |||
table="event_to_state_groups", | |||
@@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): | |||
return count | |||
def get_next_state_group(self): | |||
return self._state_groups_id_gen.get_next() | |||
@defer.inlineCallbacks | |||
def _background_deduplicate_state(self, progress, batch_size): | |||
"""This background update will slowly deduplicate state by reencoding | |||
@@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): | |||
context = EventContext() | |||
context.current_state_ids = state_ids | |||
context.prev_state_ids = state_ids | |||
elif not backfill: | |||
else: | |||
state_handler = self.hs.get_state_handler() | |||
context = yield state_handler.compute_event_context(event) | |||
else: | |||
context = EventContext() | |||
context.push_actions = push_actions | |||
@@ -80,14 +80,14 @@ class StateGroupStore(object): | |||
return defer.succeed(groups) | |||
def store_state_groups(self, event, context): | |||
if context.current_state_ids is None: | |||
return | |||
def store_state_group(self, event_id, room_id, prev_group, delta_ids, | |||
current_state_ids): | |||
state_group = self._next_group | |||
self._next_group += 1 | |||
state_events = dict(context.current_state_ids) | |||
self._group_to_state[state_group] = dict(current_state_ids) | |||
self._group_to_state[context.state_group] = state_events | |||
self._event_to_state_group[event.event_id] = context.state_group | |||
return state_group | |||
def get_events(self, event_ids, **kwargs): | |||
return { | |||
@@ -95,10 +95,19 @@ class StateGroupStore(object): | |||
if e_id in self._event_id_to_event | |||
} | |||
def get_state_group_delta(self, name): | |||
return (None, None) | |||
def register_events(self, events): | |||
for e in events: | |||
self._event_id_to_event[e.event_id] = e | |||
def register_event_context(self, event, context): | |||
self._event_to_state_group[event.event_id] = context.state_group | |||
def register_event_id_state_group(self, event_id, state_group): | |||
self._event_to_state_group[event_id] = state_group | |||
class DictObj(dict): | |||
def __init__(self, **kwargs): | |||
@@ -137,15 +146,7 @@ class Graph(object): | |||
class StateTestCase(unittest.TestCase): | |||
def setUp(self): | |||
self.store = Mock( | |||
spec_set=[ | |||
"get_state_groups_ids", | |||
"add_event_hashes", | |||
"get_events", | |||
"get_next_state_group", | |||
"get_state_group_delta", | |||
] | |||
) | |||
self.store = StateGroupStore() | |||
hs = Mock(spec_set=[ | |||
"get_datastore", "get_auth", "get_state_handler", "get_clock", | |||
"get_state_resolution_handler", | |||
@@ -156,9 +157,6 @@ class StateTestCase(unittest.TestCase): | |||
hs.get_auth.return_value = Auth(hs) | |||
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) | |||
self.store.get_next_state_group.side_effect = Mock | |||
self.store.get_state_group_delta.return_value = (None, None) | |||
self.state = StateHandler(hs) | |||
self.event_id = 0 | |||
@@ -197,14 +195,13 @@ class StateTestCase(unittest.TestCase): | |||
} | |||
) | |||
store = StateGroupStore() | |||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |||
self.store.register_events(graph.walk()) | |||
context_store = {} | |||
for event in graph.walk(): | |||
context = yield self.state.compute_event_context(event) | |||
store.store_state_groups(event, context) | |||
self.store.register_event_context(event, context) | |||
context_store[event.event_id] = context | |||
self.assertEqual(2, len(context_store["D"].prev_state_ids)) | |||
@@ -249,16 +246,13 @@ class StateTestCase(unittest.TestCase): | |||
} | |||
) | |||
store = StateGroupStore() | |||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |||
self.store.get_events = store.get_events | |||
store.register_events(graph.walk()) | |||
self.store.register_events(graph.walk()) | |||
context_store = {} | |||
for event in graph.walk(): | |||
context = yield self.state.compute_event_context(event) | |||
store.store_state_groups(event, context) | |||
self.store.register_event_context(event, context) | |||
context_store[event.event_id] = context | |||
self.assertSetEqual( | |||
@@ -315,16 +309,13 @@ class StateTestCase(unittest.TestCase): | |||
} | |||
) | |||
store = StateGroupStore() | |||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |||
self.store.get_events = store.get_events | |||
store.register_events(graph.walk()) | |||
self.store.register_events(graph.walk()) | |||
context_store = {} | |||
for event in graph.walk(): | |||
context = yield self.state.compute_event_context(event) | |||
store.store_state_groups(event, context) | |||
self.store.register_event_context(event, context) | |||
context_store[event.event_id] = context | |||
self.assertSetEqual( | |||
@@ -398,16 +389,13 @@ class StateTestCase(unittest.TestCase): | |||
self._add_depths(nodes, edges) | |||
graph = Graph(nodes, edges) | |||
store = StateGroupStore() | |||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids | |||
self.store.get_events = store.get_events | |||
store.register_events(graph.walk()) | |||
self.store.register_events(graph.walk()) | |||
context_store = {} | |||
for event in graph.walk(): | |||
context = yield self.state.compute_event_context(event) | |||
store.store_state_groups(event, context) | |||
self.store.register_event_context(event, context) | |||
context_store[event.event_id] = context | |||
self.assertSetEqual( | |||
@@ -467,7 +455,11 @@ class StateTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_trivial_annotate_message(self): | |||
event = create_event(type="test_message", name="event") | |||
prev_event_id = "prev_event_id" | |||
event = create_event( | |||
type="test_message", name="event2", | |||
prev_events=[(prev_event_id, {})], | |||
) | |||
old_state = [ | |||
create_event(type="test1", state_key="1"), | |||
@@ -475,11 +467,11 @@ class StateTestCase(unittest.TestCase): | |||
create_event(type="test2", state_key=""), | |||
] | |||
group_name = "group_name_1" | |||
self.store.get_state_groups_ids.return_value = { | |||
group_name: {(e.type, e.state_key): e.event_id for e in old_state}, | |||
} | |||
group_name = self.store.store_state_group( | |||
prev_event_id, event.room_id, None, None, | |||
{(e.type, e.state_key): e.event_id for e in old_state}, | |||
) | |||
self.store.register_event_id_state_group(prev_event_id, group_name) | |||
context = yield self.state.compute_event_context(event) | |||
@@ -492,7 +484,11 @@ class StateTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_trivial_annotate_state(self): | |||
event = create_event(type="state", state_key="", name="event") | |||
prev_event_id = "prev_event_id" | |||
event = create_event( | |||
type="state", state_key="", name="event2", | |||
prev_events=[(prev_event_id, {})], | |||
) | |||
old_state = [ | |||
create_event(type="test1", state_key="1"), | |||
@@ -500,11 +496,11 @@ class StateTestCase(unittest.TestCase): | |||
create_event(type="test2", state_key=""), | |||
] | |||
group_name = "group_name_1" | |||
self.store.get_state_groups_ids.return_value = { | |||
group_name: {(e.type, e.state_key): e.event_id for e in old_state}, | |||
} | |||
group_name = self.store.store_state_group( | |||
prev_event_id, event.room_id, None, None, | |||
{(e.type, e.state_key): e.event_id for e in old_state}, | |||
) | |||
self.store.register_event_id_state_group(prev_event_id, group_name) | |||
context = yield self.state.compute_event_context(event) | |||
@@ -517,7 +513,12 @@ class StateTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_resolve_message_conflict(self): | |||
event = create_event(type="test_message", name="event") | |||
prev_event_id1 = "event_id1" | |||
prev_event_id2 = "event_id2" | |||
event = create_event( | |||
type="test_message", name="event3", | |||
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], | |||
) | |||
creation = create_event( | |||
type=EventTypes.Create, state_key="" | |||
@@ -537,12 +538,12 @@ class StateTestCase(unittest.TestCase): | |||
create_event(type="test4", state_key=""), | |||
] | |||
store = StateGroupStore() | |||
store.register_events(old_state_1) | |||
store.register_events(old_state_2) | |||
self.store.get_events = store.get_events | |||
self.store.register_events(old_state_1) | |||
self.store.register_events(old_state_2) | |||
context = yield self._get_context(event, old_state_1, old_state_2) | |||
context = yield self._get_context( | |||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, | |||
) | |||
self.assertEqual(len(context.current_state_ids), 6) | |||
@@ -550,7 +551,12 @@ class StateTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_resolve_state_conflict(self): | |||
event = create_event(type="test4", state_key="", name="event") | |||
prev_event_id1 = "event_id1" | |||
prev_event_id2 = "event_id2" | |||
event = create_event( | |||
type="test4", state_key="", name="event", | |||
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], | |||
) | |||
creation = create_event( | |||
type=EventTypes.Create, state_key="" | |||
@@ -575,7 +581,9 @@ class StateTestCase(unittest.TestCase): | |||
store.register_events(old_state_2) | |||
self.store.get_events = store.get_events | |||
context = yield self._get_context(event, old_state_1, old_state_2) | |||
context = yield self._get_context( | |||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, | |||
) | |||
self.assertEqual(len(context.current_state_ids), 6) | |||
@@ -583,7 +591,12 @@ class StateTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_standard_depth_conflict(self): | |||
event = create_event(type="test4", name="event") | |||
prev_event_id1 = "event_id1" | |||
prev_event_id2 = "event_id2" | |||
event = create_event( | |||
type="test4", name="event", | |||
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], | |||
) | |||
member_event = create_event( | |||
type=EventTypes.Member, | |||
@@ -615,7 +628,9 @@ class StateTestCase(unittest.TestCase): | |||
store.register_events(old_state_2) | |||
self.store.get_events = store.get_events | |||
context = yield self._get_context(event, old_state_1, old_state_2) | |||
context = yield self._get_context( | |||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, | |||
) | |||
self.assertEqual( | |||
old_state_2[2].event_id, context.current_state_ids[("test1", "1")] | |||
@@ -639,19 +654,26 @@ class StateTestCase(unittest.TestCase): | |||
store.register_events(old_state_1) | |||
store.register_events(old_state_2) | |||
context = yield self._get_context(event, old_state_1, old_state_2) | |||
context = yield self._get_context( | |||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, | |||
) | |||
self.assertEqual( | |||
old_state_1[2].event_id, context.current_state_ids[("test1", "1")] | |||
) | |||
def _get_context(self, event, old_state_1, old_state_2): | |||
group_name_1 = "group_name_1" | |||
group_name_2 = "group_name_2" | |||
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, | |||
old_state_2): | |||
sg1 = self.store.store_state_group( | |||
prev_event_id_1, event.room_id, None, None, | |||
{(e.type, e.state_key): e.event_id for e in old_state_1}, | |||
) | |||
self.store.register_event_id_state_group(prev_event_id_1, sg1) | |||
self.store.get_state_groups_ids.return_value = { | |||
group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, | |||
group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, | |||
} | |||
sg2 = self.store.store_state_group( | |||
prev_event_id_2, event.room_id, None, None, | |||
{(e.type, e.state_key): e.event_id for e in old_state_2}, | |||
) | |||
self.store.register_event_id_state_group(prev_event_id_2, sg2) | |||
return self.state.compute_event_context(event) |