# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Set, Tuple

from twisted.trial import unittest

from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.storage.databases.main.events import _LinkMap
from synapse.types import create_requester

from tests.unittest import HomeserverTestCase


class EventChainStoreTestCase(HomeserverTestCase):
    """Tests for the auth chain rows and links written when persisting events."""

    def prepare(self, reactor, clock, hs) -> None:
        self.store = hs.get_datastores().main
        # Stream ordering handed to the next event passed to `persist`.
        self._next_stream_ordering = 1

    def _store_room(self, room_id: str) -> None:
        """Insert a `rooms` entry so that the chain index gets generated."""
        self.get_success(
            self.store.store_room(
                room_id=room_id,
                room_creator_user_id="",
                is_public=True,
                room_version=RoomVersions.V6,
            )
        )

    def _build_event(
        self,
        room_id: str,
        event_type: str,
        state_key: str,
        sender: str,
        tag: str,
        auth_events: List[EventBase],
    ) -> EventBase:
        """Build (but do not persist) a V6 state event with no prev events.

        Args:
            room_id: The room the event belongs to.
            event_type: The event's `type`.
            state_key: The event's `state_key`.
            sender: The user ID sending the event.
            tag: Value stored under `content["tag"]`, naming the event.
            auth_events: Events whose IDs become the new event's auth events,
                in the order given.

        Returns:
            The built event.
        """
        event_factory = self.hs.get_event_builder_factory()
        return self.get_success(
            event_factory.for_room_version(
                RoomVersions.V6,
                {
                    "type": event_type,
                    "state_key": state_key,
                    "sender": sender,
                    "room_id": room_id,
                    "content": {"tag": tag},
                },
            ).build(
                prev_event_ids=[],
                auth_event_ids=[e.event_id for e in auth_events],
            )
        )

    def _assert_links(
        self,
        expected_links: List[Tuple[EventBase, EventBase]],
        chain_map: Dict[str, Tuple[int, int]],
        link_map: _LinkMap,
    ) -> None:
        """Check that the expected links, and only those, have been added.

        Args:
            expected_links: `(start, end)` event pairs that must be linked.
            chain_map: Map from event ID to (chain ID, sequence number).
            link_map: The links fetched from the database.
        """
        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))

        for start, end in expected_links:
            start_id, start_seq = chain_map[start.event_id]
            end_id, end_seq = chain_map[end.event_id]

            self.assertIn(
                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
            )

    def test_simple(self) -> None:
        """Test that the example in `docs/auth_chain_difference_algorithm.md`
        works.
        """
        bob = "@creator:test"
        alice = "@alice:test"
        room_id = "!room:test"

        # Ensure that we have a rooms entry so that we generate the chain index.
        self._store_room(room_id)

        create = self._build_event(room_id, EventTypes.Create, "", bob, "create", [])
        bob_join = self._build_event(
            room_id, EventTypes.Member, bob, bob, "bob_join", [create]
        )
        power = self._build_event(
            room_id, EventTypes.PowerLevels, "", bob, "power", [create, bob_join]
        )
        alice_invite = self._build_event(
            room_id,
            EventTypes.Member,
            alice,
            bob,
            "alice_invite",
            [create, bob_join, power],
        )
        alice_join = self._build_event(
            room_id,
            EventTypes.Member,
            alice,
            alice,
            "alice_join",
            [create, alice_invite, power],
        )
        power_2 = self._build_event(
            room_id,
            EventTypes.PowerLevels,
            "",
            bob,
            "power_2",
            [create, bob_join, power],
        )
        bob_join_2 = self._build_event(
            room_id,
            EventTypes.Member,
            bob,
            bob,
            "bob_join_2",
            [create, bob_join, power],
        )
        alice_join2 = self._build_event(
            room_id,
            EventTypes.Member,
            alice,
            alice,
            "alice_join2",
            [create, alice_join, power_2],
        )

        events = [
            create,
            bob_join,
            power,
            alice_invite,
            alice_join,
            bob_join_2,
            power_2,
            alice_join2,
        ]

        expected_links = [
            (bob_join, create),
            (power, create),
            (power, bob_join),
            (alice_invite, create),
            (alice_invite, power),
            (alice_invite, bob_join),
            (bob_join_2, power),
            (alice_join2, power_2),
        ]

        self.persist(events)
        chain_map, link_map = self.fetch_chains(events)

        # Check that the expected links and only the expected links have been
        # added.
        self._assert_links(expected_links, chain_map, link_map)

        # Test that everything can reach the create event, but the create event
        # can't reach anything.
        for event in events[1:]:
            self.assertTrue(
                link_map.exists_path_from(
                    chain_map[event.event_id], chain_map[create.event_id]
                ),
            )

            self.assertFalse(
                link_map.exists_path_from(
                    chain_map[create.event_id],
                    chain_map[event.event_id],
                ),
            )

    def test_out_of_order_events(self) -> None:
        """Test that we handle persisting events that we don't have the full
        auth chain for yet (which should only happen for out of band
        memberships).
        """
        bob = "@creator:test"
        alice = "@alice:test"
        room_id = "!room:test"

        # Ensure that we have a rooms entry so that we generate the chain index.
        self._store_room(room_id)

        # First persist the base room.
        create = self._build_event(room_id, EventTypes.Create, "", bob, "create", [])
        bob_join = self._build_event(
            room_id, EventTypes.Member, bob, bob, "bob_join", [create]
        )
        power = self._build_event(
            room_id, EventTypes.PowerLevels, "", bob, "power", [create, bob_join]
        )

        self.persist([create, bob_join, power])

        # Now persist an invite and a couple of memberships out of order.
        alice_invite = self._build_event(
            room_id,
            EventTypes.Member,
            alice,
            bob,
            "alice_invite",
            [create, bob_join, power],
        )
        alice_join = self._build_event(
            room_id,
            EventTypes.Member,
            alice,
            alice,
            "alice_join",
            [create, alice_invite, power],
        )
        alice_join2 = self._build_event(
            room_id,
            EventTypes.Member,
            alice,
            alice,
            "alice_join2",
            [create, alice_join, power],
        )

        self.persist([alice_join])
        self.persist([alice_join2])
        self.persist([alice_invite])

        # The end result should be sane.
        events = [create, bob_join, power, alice_invite, alice_join]

        chain_map, link_map = self.fetch_chains(events)

        expected_links = [
            (bob_join, create),
            (power, create),
            (power, bob_join),
            (alice_invite, create),
            (alice_invite, power),
            (alice_invite, bob_join),
        ]

        # Check that the expected links and only the expected links have been
        # added.
        self._assert_links(expected_links, chain_map, link_map)

    def persist(
        self,
        events: List[EventBase],
    ) -> None:
        """Persist the given events and calculate their auth chain entries.

        Each event is assigned the next stream ordering before being written.
        No assertions are made here; use `fetch_chains` to inspect the result.

        Args:
            events: The events to persist, in the order they should be stored.
        """
        persist_events_store = self.hs.get_datastores().persist_events

        for e in events:
            e.internal_metadata.stream_ordering = self._next_stream_ordering
            self._next_stream_ordering += 1

        def _persist(txn):
            # We need to persist the events to the events and state_events
            # tables.
            persist_events_store._store_event_txn(
                txn,
                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
            )

            # Actually call the function that calculates the auth chain stuff.
            persist_events_store._persist_event_auth_chain_txn(txn, events)

        self.get_success(
            persist_events_store.db_pool.runInteraction(
                "_persist",
                _persist,
            )
        )

    def fetch_chains(
        self, events: List[EventBase]
    ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
        """Load the chain positions and chain links stored for `events`.

        Also asserts that none of the stored links is redundant.

        Args:
            events: The events whose chain information should be fetched.

        Returns:
            A tuple of a map from event ID to (chain ID, sequence number) and
            a `_LinkMap` populated with every link originating from one of
            those chains.
        """
        # Fetch the map from event ID -> (chain ID, sequence number)
        rows = self.get_success(
            self.store.db_pool.simple_select_many_batch(
                table="event_auth_chains",
                column="event_id",
                iterable=[e.event_id for e in events],
                retcols=("event_id", "chain_id", "sequence_number"),
                keyvalues={},
            )
        )

        chain_map = {
            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
        }

        # Fetch all the links and pass them to the _LinkMap.
        rows = self.get_success(
            self.store.db_pool.simple_select_many_batch(
                table="event_auth_chain_links",
                column="origin_chain_id",
                iterable=[chain_id for chain_id, _ in chain_map.values()],
                retcols=(
                    "origin_chain_id",
                    "origin_sequence_number",
                    "target_chain_id",
                    "target_sequence_number",
                ),
                keyvalues={},
            )
        )

        link_map = _LinkMap()
        for row in rows:
            added = link_map.add_link(
                (row["origin_chain_id"], row["origin_sequence_number"]),
                (row["target_chain_id"], row["target_sequence_number"]),
            )

            # We shouldn't have persisted any redundant links
            self.assertTrue(added)

        return chain_map, link_map


class LinkMapTestCase(unittest.TestCase):
    """Unit tests for `_LinkMap` that do not need a homeserver."""

    def test_simple(self) -> None:
        """Basic tests for the LinkMap."""
        link_map = _LinkMap()

        # A link added with `new=False` is queryable but is not reported as an
        # addition.
        link_map.add_link((1, 1), (2, 1), new=False)
        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
        self.assertCountEqual(link_map.get_additions(), [])
        self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
        self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
        self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
        self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))

        # Attempting to add a redundant link is ignored.
        self.assertFalse(link_map.add_link((1, 4), (2, 1)))
        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])

        # Adding new non-redundant links works
        self.assertTrue(link_map.add_link((1, 3), (2, 3)))
        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])

        self.assertTrue(link_map.add_link((2, 5), (1, 3)))
        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])

        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])


class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
    """Tests for the `chain_cover` background update on historic rooms."""

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor, clock, hs) -> None:
        self.store = hs.get_datastores().main
        self.user_id = self.register_user("foo", "pass")
        self.token = self.login("foo", "pass")
        self.requester = create_requester(self.user_id)

    def _generate_room(self) -> Tuple[str, List[Set[str]]]:
        """Insert a room without a chain cover index.

        Returns:
            A tuple of the room ID and a two-element list holding, for each
            side of a fork created in the DAG, the set of values returned by
            that event context's `get_current_state_ids()`.
        """
        room_id = self.helper.create_room_as(self.user_id, tok=self.token)

        # Mark the room as not having a chain cover index
        self.get_success(
            self.store.db_pool.simple_update(
                table="rooms",
                keyvalues={"room_id": room_id},
                updatevalues={"has_auth_chain_index": False},
                desc="test",
            )
        )

        # Create a fork in the DAG with different events.
        event_handler = self.hs.get_event_creation_handler()
        latest_event_ids = self.get_success(
            self.store.get_prev_events_for_room(room_id)
        )

        # Both events are built on top of the same prev events, which is what
        # produces the fork.
        fork_states: List[Set[str]] = []
        for _ in range(2):
            event, context = self.get_success(
                event_handler.create_event(
                    self.requester,
                    {
                        "type": "some_state_type",
                        "state_key": "",
                        "content": {},
                        "room_id": room_id,
                        "sender": self.user_id,
                    },
                    prev_event_ids=latest_event_ids,
                )
            )
            self.get_success(
                event_handler.handle_new_client_event(
                    self.requester, events_and_context=[(event, context)]
                )
            )
            fork_states.append(
                set(self.get_success(context.get_current_state_ids()).values())
            )

        # Delete the chain cover info.
        def _delete_tables(txn):
            txn.execute("DELETE FROM event_auth_chains")
            txn.execute("DELETE FROM event_auth_chain_links")

        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))

        return room_id, fork_states

    def _queue_chain_cover_update(self) -> None:
        """Insert the `chain_cover` background update so that it will be run."""
        self.get_success(
            self.store.db_pool.simple_insert(
                "background_updates",
                {"update_name": "chain_cover", "progress_json": "{}"},
            )
        )

        # Ugh, have to reset this flag
        self.store.db_pool.updates._all_done = False

    def _run_background_updates_counting(self) -> int:
        """Run background updates one step at a time until they complete.

        Returns:
            The number of `do_next_background_update` calls that were needed.
        """
        iterations = 0
        while not self.get_success(
            self.store.db_pool.updates.has_completed_background_updates()
        ):
            iterations += 1
            self.get_success(
                self.store.db_pool.updates.do_next_background_update(False), by=0.1
            )
        return iterations

    def _assert_chain_cover_usable(self, room_id: str, states: List[Set[str]]) -> None:
        """Check that the auth chain difference can be calculated for the room
        using the chain cover index, i.e. that the call succeeds.
        """
        self.get_success(
            self.store.db_pool.runInteraction(
                "test",
                self.store._get_auth_chain_difference_using_cover_index_txn,
                room_id,
                states,
            )
        )

    def test_background_update_single_room(self) -> None:
        """Test that the background update to calculate auth chains for historic
        rooms works correctly for a single room.
        """
        # Create a room
        room_id, states = self._generate_room()

        # Insert and run the background update.
        self._queue_chain_cover_update()
        self.wait_for_background_updates()

        # Test that the `has_auth_chain_index` has been set
        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))

        # Test that calculating the auth chain difference using the newly
        # calculated chain cover works.
        self._assert_chain_cover_usable(room_id, states)

    def test_background_update_multiple_rooms(self) -> None:
        """Test that the background update to calculate auth chains for historic
        rooms works correctly when several rooms need processing.
        """
        # Create the rooms, keeping each room's own state sets so that none of
        # them is silently overwritten.
        rooms = [self._generate_room() for _ in range(3)]

        # Insert and run the background update.
        self._queue_chain_cover_update()
        self.wait_for_background_updates()

        # Test that the `has_auth_chain_index` has been set
        for room_id, _ in rooms:
            self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))

        # Test that calculating the auth chain difference using the newly
        # calculated chain cover works for every room, not just the first.
        for room_id, states in rooms:
            self._assert_chain_cover_usable(room_id, states)

    def test_background_update_single_large_room(self) -> None:
        """Test that the background update to calculate auth chains for historic
        rooms works correctly when one room needs multiple iterations.
        """
        # Create a room
        room_id, states = self._generate_room()

        # Add a bunch of state so that it takes multiple iterations of the
        # background update to process the room.
        for i in range(150):
            self.helper.send_state(
                room_id, event_type="m.test", body={"index": i}, tok=self.token
            )

        # Insert and run the background update.
        self._queue_chain_cover_update()
        iterations = self._run_background_updates_counting()

        # Ensure that we did actually take multiple iterations to process the
        # room.
        self.assertGreater(iterations, 1)

        # Test that the `has_auth_chain_index` has been set
        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))

        # Test that calculating the auth chain difference using the newly
        # calculated chain cover works.
        self._assert_chain_cover_usable(room_id, states)

    def test_background_update_multiple_large_room(self) -> None:
        """Test that the background update to calculate auth chains for historic
        rooms works correctly when several rooms each need multiple iterations.
        """
        # Create the rooms
        room_id1, _ = self._generate_room()
        room_id2, _ = self._generate_room()

        # Add a bunch of state so that it takes multiple iterations of the
        # background update to process the room.
        for i in range(150):
            self.helper.send_state(
                room_id1, event_type="m.test", body={"index": i}, tok=self.token
            )

        for i in range(150):
            self.helper.send_state(
                room_id2, event_type="m.test", body={"index": i}, tok=self.token
            )

        # Insert and run the background update.
        self._queue_chain_cover_update()
        iterations = self._run_background_updates_counting()

        # Ensure that we did actually take multiple iterations to process the
        # room.
        self.assertGreater(iterations, 1)

        # Test that the `has_auth_chain_index` has been set
        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))