|
|
@@ -67,6 +67,8 @@ class StateGroupStore(object): |
|
|
|
self._event_to_state_group = {} |
|
|
|
self._group_to_state = {} |
|
|
|
|
|
|
|
self._event_id_to_event = {} |
|
|
|
|
|
|
|
self._next_group = 1 |
|
|
|
|
|
|
|
def get_state_groups_ids(self, room_id, event_ids): |
|
|
@@ -96,6 +98,16 @@ class StateGroupStore(object): |
|
|
|
|
|
|
|
self._event_to_state_group[event.event_id] = state_group |
|
|
|
|
|
|
|
def get_events(self, event_ids, **kwargs): |
|
|
|
return { |
|
|
|
e_id: self._event_id_to_event[e_id] for e_id in event_ids |
|
|
|
if e_id in self._event_id_to_event |
|
|
|
} |
|
|
|
|
|
|
|
def register_events(self, events): |
|
|
|
for e in events: |
|
|
|
self._event_id_to_event[e.event_id] = e |
|
|
|
|
|
|
|
|
|
|
|
class DictObj(dict): |
|
|
|
def __init__(self, **kwargs): |
|
|
@@ -138,6 +150,7 @@ class StateTestCase(unittest.TestCase): |
|
|
|
spec_set=[ |
|
|
|
"get_state_groups_ids", |
|
|
|
"add_event_hashes", |
|
|
|
"get_events", |
|
|
|
] |
|
|
|
) |
|
|
|
hs = Mock(spec_set=[ |
|
|
@@ -240,6 +253,8 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
|
|
|
store = StateGroupStore() |
|
|
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids |
|
|
|
self.store.get_events = store.get_events |
|
|
|
store.register_events(graph.walk()) |
|
|
|
|
|
|
|
context_store = {} |
|
|
|
|
|
|
@@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
|
|
|
self.assertSetEqual( |
|
|
|
{"START", "A", "C"}, |
|
|
|
{e.event_id for e in context_store["D"].current_state.values()} |
|
|
|
{e_id for e_id in context_store["D"].current_state_ids.values()} |
|
|
|
) |
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
@@ -304,6 +319,8 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
|
|
|
store = StateGroupStore() |
|
|
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids |
|
|
|
self.store.get_events = store.get_events |
|
|
|
store.register_events(graph.walk()) |
|
|
|
|
|
|
|
context_store = {} |
|
|
|
|
|
|
@@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
|
|
|
self.assertSetEqual( |
|
|
|
{"START", "A", "B", "C"}, |
|
|
|
{e.event_id for e in context_store["E"].current_state.values()} |
|
|
|
{e for e in context_store["E"].current_state_ids.values()} |
|
|
|
) |
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
@@ -385,6 +402,8 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
|
|
|
store = StateGroupStore() |
|
|
|
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids |
|
|
|
self.store.get_events = store.get_events |
|
|
|
store.register_events(graph.walk()) |
|
|
|
|
|
|
|
context_store = {} |
|
|
|
|
|
|
@@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase): |
|
|
|
|
|
|
|
self.assertSetEqual( |
|
|
|
{"A1", "A2", "A3", "A5", "B"}, |
|
|
|
{e.event_id for e in context_store["D"].current_state.values()} |
|
|
|
{e for e in context_store["D"].current_state_ids.values()} |
|
|
|
) |
|
|
|
|
|
|
|
def _add_depths(self, nodes, edges): |
|
|
@@ -522,6 +541,11 @@ class StateTestCase(unittest.TestCase): |
|
|
|
create_event(type="test4", state_key=""), |
|
|
|
] |
|
|
|
|
|
|
|
store = StateGroupStore() |
|
|
|
store.register_events(old_state_1) |
|
|
|
store.register_events(old_state_2) |
|
|
|
self.store.get_events = store.get_events |
|
|
|
|
|
|
|
context = yield self._get_context(event, old_state_1, old_state_2) |
|
|
|
|
|
|
|
self.assertEqual(len(context.current_state_ids), 6) |
|
|
@@ -550,6 +574,11 @@ class StateTestCase(unittest.TestCase): |
|
|
|
create_event(type="test4", state_key=""), |
|
|
|
] |
|
|
|
|
|
|
|
store = StateGroupStore() |
|
|
|
store.register_events(old_state_1) |
|
|
|
store.register_events(old_state_2) |
|
|
|
self.store.get_events = store.get_events |
|
|
|
|
|
|
|
context = yield self._get_context(event, old_state_1, old_state_2) |
|
|
|
|
|
|
|
self.assertEqual(len(context.current_state_ids), 6) |
|
|
@@ -585,9 +614,16 @@ class StateTestCase(unittest.TestCase): |
|
|
|
create_event(type="test1", state_key="1", depth=2), |
|
|
|
] |
|
|
|
|
|
|
|
store = StateGroupStore() |
|
|
|
store.register_events(old_state_1) |
|
|
|
store.register_events(old_state_2) |
|
|
|
self.store.get_events = store.get_events |
|
|
|
|
|
|
|
context = yield self._get_context(event, old_state_1, old_state_2) |
|
|
|
|
|
|
|
self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")]) |
|
|
|
self.assertEqual( |
|
|
|
old_state_2[2].event_id, context.current_state_ids[("test1", "1")] |
|
|
|
) |
|
|
|
|
|
|
|
# Reverse the depth to make sure we are actually using the depths |
|
|
|
# during state resolution. |
|
|
@@ -604,9 +640,14 @@ class StateTestCase(unittest.TestCase): |
|
|
|
create_event(type="test1", state_key="1", depth=1), |
|
|
|
] |
|
|
|
|
|
|
|
store.register_events(old_state_1) |
|
|
|
store.register_events(old_state_2) |
|
|
|
|
|
|
|
context = yield self._get_context(event, old_state_1, old_state_2) |
|
|
|
|
|
|
|
self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")]) |
|
|
|
self.assertEqual( |
|
|
|
old_state_1[2].event_id, context.current_state_ids[("test1", "1")] |
|
|
|
) |
|
|
|
|
|
|
|
def _get_context(self, event, old_state_1, old_state_2): |
|
|
|
group_name_1 = "group_name_1" |
|
|
|