diff --git a/changelog.d/16326.misc b/changelog.d/16326.misc new file mode 100644 index 0000000000..93ceaeafc9 --- /dev/null +++ b/changelog.d/16326.misc @@ -0,0 +1 @@ +Improve type hints. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 1165c017ba..43469b170f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -103,7 +103,7 @@ class EventBuilder: async def build( self, - prev_event_ids: StrCollection, + prev_event_ids: List[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ) -> EventBase: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index d32d224d56..eedde97ab0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -723,12 +723,11 @@ class FederationEventHandler: if not prevs - seen: return - latest_list = await self._store.get_latest_event_ids_in_room(room_id) + latest_frozen = await self._store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us - latest = set(latest_list) - latest |= seen + latest = seen | latest_frozen logger.info( "Requesting missing events between %s and %s", @@ -1976,8 +1975,7 @@ class FederationEventHandler: # partial and full state and may not be accurate. return - extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids_list) + extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id) prev_event_ids = set(event.prev_event_ids()) if extrem_ids == prev_event_ids: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 6864f93090..f39ae2d635 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -19,6 +19,7 @@ import logging from collections import deque from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Callable, @@ -618,7 +619,7 @@ class EventsPersistenceStorageController: ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = set( + latest_event_ids = ( await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( @@ -740,7 +741,7 @@ class EventsPersistenceStorageController: self, room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], - latest_event_ids: Collection[str], + latest_event_ids: AbstractSet[str], ) -> Set[str]: """Calculates the new forward extremities for a room given events to persist. @@ -758,8 +759,6 @@ class EventsPersistenceStorageController: and not event.internal_metadata.is_soft_failed() ] - latest_event_ids = set(latest_event_ids) - # start with the existing forward extremities result = set(latest_event_ids) @@ -798,7 +797,7 @@ class EventsPersistenceStorageController: self, room_id: str, events_context: List[Tuple[EventBase, EventContext]], - old_latest_event_ids: Set[str], + old_latest_event_ids: AbstractSet[str], new_latest_event_ids: Set[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: """Calculate the current state dict after adding some new events to diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 09de8f55e2..afffa54985 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -19,6 +19,7 @@ from typing import ( TYPE_CHECKING, Collection, Dict, + FrozenSet, Iterable, List, Optional, @@ -47,7 +48,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, StrCollection, StrSequence +from synapse.types import JsonDict, StrCollection from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -1179,13 +1180,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence: - return await self.db_pool.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]: + event_ids = await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", desc="get_latest_event_ids_in_room", ) + return frozenset(event_ids) async def get_min_depth(self, room_id: str) -> Optional[int]: """For the given room, get the minimum depth we have seen for it.""" diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0c1ed75240..bc8474a589 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -222,7 +222,7 @@ class PersistEventsStore: for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) + (room_id,), frozenset(latest_event_ids) ) async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 638787b029..41c8c44e02 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -1858,7 +1858,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) event = self.get_success( - builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) + builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) ) self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index af25815fa5..33c277a38a 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -90,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): def test_get_latest_event_ids_in_room(self) -> None: create = self.persist(type="m.room.create", key="", creator=USER_ID) self.replicate() - self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) + self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id}) join = self.persist( type="m.room.member", @@ -99,7 +99,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): prev_events=[(create.event_id, {})], ) self.replicate() - self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) + self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id}) def test_redactions(self) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 65ef4bb160..128fc3e046 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional from twisted.test.proto_helpers import MemoryReactor @@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: Sequence[str] = self.get_success( + fork_point = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: Sequence[str] = self.get_success( + fork_point = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -316,14 +316,14 @@ class EventsStreamTestCase(BaseStreamTestCase): self.test_handler.received_rdata_rows.clear() # now roll back all that state by de-modding the users - prev_events = fork_point + prev_events = list(fork_point) pl_events = [] for u in user_ids: pls["users"][u] = 0 e = self.get_success( inject_event( self.hs, - prev_event_ids=list(prev_events), + prev_event_ids=prev_events, type=EventTypes.PowerLevels, state_key="", sender=self.user_id, diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 9b28cd474f..59f4fdc70b 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -261,7 +261,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): builder = factory.for_room_version(room_version, event_dict) join_event = self.get_success( - builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) + builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) ) self.get_success(federation.on_send_membership_event(remote_server, join_event)) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 7de109966d..ceb9597dd3 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -120,7 +120,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_4]) + self.assertEqual(latest_event_ids, {event_id_4}) def test_basic_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of @@ -147,7 +147,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -155,7 +155,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_b]) + self.assertEqual(latest_event_ids, {event_id_b}) def test_chain_of_fail_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of @@ -185,7 +185,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -193,7 +193,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_b]) + self.assertEqual(latest_event_ids, {event_id_b}) def test_forked_graph_cleanup(self) -> None: r"""Test that extremities are correctly calculated in the presence of @@ -240,7 +240,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -248,7 +248,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) + self.assertEqual(latest_event_ids, {event_id_b, event_id_c}) class CleanupExtremDummyEventsTestCase(HomeserverTestCase): diff --git a/tests/test_federation.py b/tests/test_federation.py index f8ade6da38..1b0504709e 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -51,9 +51,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.store = self.hs.get_datastores().main # Figure out what the most recent event is - most_recent = self.get_success( - self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) - )[0] + most_recent = next( + iter( + self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) + ) + ) + ) join_event = make_event_from_dict( { @@ -100,8 +106,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Make sure we actually joined the room self.assertEqual( - self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], - "$join:test.serv", + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)), + {"$join:test.serv"}, ) def test_cant_hide_direct_ancestors(self) -> None: @@ -127,9 +133,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.get_success( - self.store.get_latest_event_ids_in_room(self.room_id) - )[0] + most_recent = next( + iter( + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + ) + ) # Now lie about an event lying_event = make_event_from_dict( @@ -165,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Make sure the invalid event isn't there extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) - self.assertEqual(extrem[0], "$join:test.serv") + self.assertEqual(extrem, {"$join:test.serv"}) def test_retry_device_list_resync(self) -> None: """Tests that device lists are marked as stale if they couldn't be synced, and