@@ -0,0 +1 @@ | |||||
Improve type hints. |
@@ -103,7 +103,7 @@ class EventBuilder: | |||||
async def build( | async def build( | ||||
self, | self, | ||||
prev_event_ids: StrCollection, | |||||
prev_event_ids: List[str], | |||||
auth_event_ids: Optional[List[str]], | auth_event_ids: Optional[List[str]], | ||||
depth: Optional[int] = None, | depth: Optional[int] = None, | ||||
) -> EventBase: | ) -> EventBase: | ||||
@@ -723,12 +723,11 @@ class FederationEventHandler: | |||||
if not prevs - seen: | if not prevs - seen: | ||||
return | return | ||||
latest_list = await self._store.get_latest_event_ids_in_room(room_id) | |||||
latest_frozen = await self._store.get_latest_event_ids_in_room(room_id) | |||||
# We add the prev events that we have seen to the latest | # We add the prev events that we have seen to the latest | ||||
# list to ensure the remote server doesn't give them to us | # list to ensure the remote server doesn't give them to us | ||||
latest = set(latest_list) | |||||
latest |= seen | |||||
latest = seen | latest_frozen | |||||
logger.info( | logger.info( | ||||
"Requesting missing events between %s and %s", | "Requesting missing events between %s and %s", | ||||
@@ -1976,8 +1975,7 @@ class FederationEventHandler: | |||||
# partial and full state and may not be accurate. | # partial and full state and may not be accurate. | ||||
return | return | ||||
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) | |||||
extrem_ids = set(extrem_ids_list) | |||||
extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id) | |||||
prev_event_ids = set(event.prev_event_ids()) | prev_event_ids = set(event.prev_event_ids()) | ||||
if extrem_ids == prev_event_ids: | if extrem_ids == prev_event_ids: | ||||
@@ -19,6 +19,7 @@ import logging | |||||
from collections import deque | from collections import deque | ||||
from typing import ( | from typing import ( | ||||
TYPE_CHECKING, | TYPE_CHECKING, | ||||
AbstractSet, | |||||
Any, | Any, | ||||
Awaitable, | Awaitable, | ||||
Callable, | Callable, | ||||
@@ -618,7 +619,7 @@ class EventsPersistenceStorageController: | |||||
) | ) | ||||
for room_id, ev_ctx_rm in events_by_room.items(): | for room_id, ev_ctx_rm in events_by_room.items(): | ||||
latest_event_ids = set( | |||||
latest_event_ids = ( | |||||
await self.main_store.get_latest_event_ids_in_room(room_id) | await self.main_store.get_latest_event_ids_in_room(room_id) | ||||
) | ) | ||||
new_latest_event_ids = await self._calculate_new_extremities( | new_latest_event_ids = await self._calculate_new_extremities( | ||||
@@ -740,7 +741,7 @@ class EventsPersistenceStorageController: | |||||
self, | self, | ||||
room_id: str, | room_id: str, | ||||
event_contexts: List[Tuple[EventBase, EventContext]], | event_contexts: List[Tuple[EventBase, EventContext]], | ||||
latest_event_ids: Collection[str], | |||||
latest_event_ids: AbstractSet[str], | |||||
) -> Set[str]: | ) -> Set[str]: | ||||
"""Calculates the new forward extremities for a room given events to | """Calculates the new forward extremities for a room given events to | ||||
persist. | persist. | ||||
@@ -758,8 +759,6 @@ class EventsPersistenceStorageController: | |||||
and not event.internal_metadata.is_soft_failed() | and not event.internal_metadata.is_soft_failed() | ||||
] | ] | ||||
latest_event_ids = set(latest_event_ids) | |||||
# start with the existing forward extremities | # start with the existing forward extremities | ||||
result = set(latest_event_ids) | result = set(latest_event_ids) | ||||
@@ -798,7 +797,7 @@ class EventsPersistenceStorageController: | |||||
self, | self, | ||||
room_id: str, | room_id: str, | ||||
events_context: List[Tuple[EventBase, EventContext]], | events_context: List[Tuple[EventBase, EventContext]], | ||||
old_latest_event_ids: Set[str], | |||||
old_latest_event_ids: AbstractSet[str], | |||||
new_latest_event_ids: Set[str], | new_latest_event_ids: Set[str], | ||||
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: | ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: | ||||
"""Calculate the current state dict after adding some new events to | """Calculate the current state dict after adding some new events to | ||||
@@ -19,6 +19,7 @@ from typing import ( | |||||
TYPE_CHECKING, | TYPE_CHECKING, | ||||
Collection, | Collection, | ||||
Dict, | Dict, | ||||
FrozenSet, | |||||
Iterable, | Iterable, | ||||
List, | List, | ||||
Optional, | Optional, | ||||
@@ -47,7 +48,7 @@ from synapse.storage.database import ( | |||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | from synapse.storage.databases.main.events_worker import EventsWorkerStore | ||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore | from synapse.storage.databases.main.signatures import SignatureWorkerStore | ||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine | from synapse.storage.engines import PostgresEngine, Sqlite3Engine | ||||
from synapse.types import JsonDict, StrCollection, StrSequence | |||||
from synapse.types import JsonDict, StrCollection | |||||
from synapse.util import json_encoder | from synapse.util import json_encoder | ||||
from synapse.util.caches.descriptors import cached | from synapse.util.caches.descriptors import cached | ||||
from synapse.util.caches.lrucache import LruCache | from synapse.util.caches.lrucache import LruCache | ||||
@@ -1179,13 +1180,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||||
) | ) | ||||
@cached(max_entries=5000, iterable=True) | @cached(max_entries=5000, iterable=True) | ||||
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence: | |||||
return await self.db_pool.simple_select_onecol( | |||||
async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]: | |||||
event_ids = await self.db_pool.simple_select_onecol( | |||||
table="event_forward_extremities", | table="event_forward_extremities", | ||||
keyvalues={"room_id": room_id}, | keyvalues={"room_id": room_id}, | ||||
retcol="event_id", | retcol="event_id", | ||||
desc="get_latest_event_ids_in_room", | desc="get_latest_event_ids_in_room", | ||||
) | ) | ||||
return frozenset(event_ids) | |||||
async def get_min_depth(self, room_id: str) -> Optional[int]: | async def get_min_depth(self, room_id: str) -> Optional[int]: | ||||
"""For the given room, get the minimum depth we have seen for it.""" | """For the given room, get the minimum depth we have seen for it.""" | ||||
@@ -222,7 +222,7 @@ class PersistEventsStore: | |||||
for room_id, latest_event_ids in new_forward_extremities.items(): | for room_id, latest_event_ids in new_forward_extremities.items(): | ||||
self.store.get_latest_event_ids_in_room.prefill( | self.store.get_latest_event_ids_in_room.prefill( | ||||
(room_id,), list(latest_event_ids) | |||||
(room_id,), frozenset(latest_event_ids) | |||||
) | ) | ||||
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: | async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: | ||||
@@ -1858,7 +1858,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
event = self.get_success( | event = self.get_success( | ||||
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) | |||||
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) | |||||
) | ) | ||||
self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) | self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) | ||||
@@ -90,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): | |||||
def test_get_latest_event_ids_in_room(self) -> None: | def test_get_latest_event_ids_in_room(self) -> None: | ||||
create = self.persist(type="m.room.create", key="", creator=USER_ID) | create = self.persist(type="m.room.create", key="", creator=USER_ID) | ||||
self.replicate() | self.replicate() | ||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) | |||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id}) | |||||
join = self.persist( | join = self.persist( | ||||
type="m.room.member", | type="m.room.member", | ||||
@@ -99,7 +99,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): | |||||
prev_events=[(create.event_id, {})], | prev_events=[(create.event_id, {})], | ||||
) | ) | ||||
self.replicate() | self.replicate() | ||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) | |||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id}) | |||||
def test_redactions(self) -> None: | def test_redactions(self) -> None: | ||||
self.persist(type="m.room.create", key="", creator=USER_ID) | self.persist(type="m.room.create", key="", creator=USER_ID) | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Any, List, Optional, Sequence | |||||
from typing import Any, List, Optional | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase): | |||||
) | ) | ||||
# this is the point in the DAG where we make a fork | # this is the point in the DAG where we make a fork | ||||
fork_point: Sequence[str] = self.get_success( | |||||
fork_point = self.get_success( | |||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase): | |||||
) | ) | ||||
# this is the point in the DAG where we make a fork | # this is the point in the DAG where we make a fork | ||||
fork_point: Sequence[str] = self.get_success( | |||||
fork_point = self.get_success( | |||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
@@ -316,14 +316,14 @@ class EventsStreamTestCase(BaseStreamTestCase): | |||||
self.test_handler.received_rdata_rows.clear() | self.test_handler.received_rdata_rows.clear() | ||||
# now roll back all that state by de-modding the users | # now roll back all that state by de-modding the users | ||||
prev_events = fork_point | |||||
prev_events = list(fork_point) | |||||
pl_events = [] | pl_events = [] | ||||
for u in user_ids: | for u in user_ids: | ||||
pls["users"][u] = 0 | pls["users"][u] = 0 | ||||
e = self.get_success( | e = self.get_success( | ||||
inject_event( | inject_event( | ||||
self.hs, | self.hs, | ||||
prev_event_ids=list(prev_events), | |||||
prev_event_ids=prev_events, | |||||
type=EventTypes.PowerLevels, | type=EventTypes.PowerLevels, | ||||
state_key="", | state_key="", | ||||
sender=self.user_id, | sender=self.user_id, | ||||
@@ -261,7 +261,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): | |||||
builder = factory.for_room_version(room_version, event_dict) | builder = factory.for_room_version(room_version, event_dict) | ||||
join_event = self.get_success( | join_event = self.get_success( | ||||
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) | |||||
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) | |||||
) | ) | ||||
self.get_success(federation.on_send_membership_event(remote_server, join_event)) | self.get_success(federation.on_send_membership_event(remote_server, join_event)) | ||||
@@ -120,7 +120,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||||
self.store.get_latest_event_ids_in_room(self.room_id) | self.store.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
self.assertEqual(latest_event_ids, [event_id_4]) | |||||
self.assertEqual(latest_event_ids, {event_id_4}) | |||||
def test_basic_cleanup(self) -> None: | def test_basic_cleanup(self) -> None: | ||||
"""Test that extremities are correctly calculated in the presence of | """Test that extremities are correctly calculated in the presence of | ||||
@@ -147,7 +147,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||||
latest_event_ids = self.get_success( | latest_event_ids = self.get_success( | ||||
self.store.get_latest_event_ids_in_room(self.room_id) | self.store.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) | |||||
self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) | |||||
# Run the background update and check it did the right thing | # Run the background update and check it did the right thing | ||||
self.run_background_update() | self.run_background_update() | ||||
@@ -155,7 +155,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||||
latest_event_ids = self.get_success( | latest_event_ids = self.get_success( | ||||
self.store.get_latest_event_ids_in_room(self.room_id) | self.store.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
self.assertEqual(latest_event_ids, [event_id_b]) | |||||
self.assertEqual(latest_event_ids, {event_id_b}) | |||||
def test_chain_of_fail_cleanup(self) -> None: | def test_chain_of_fail_cleanup(self) -> None: | ||||
"""Test that extremities are correctly calculated in the presence of | """Test that extremities are correctly calculated in the presence of | ||||
@@ -185,7 +185,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||||
latest_event_ids = self.get_success( | latest_event_ids = self.get_success( | ||||
self.store.get_latest_event_ids_in_room(self.room_id) | self.store.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) | |||||
self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) | |||||
# Run the background update and check it did the right thing | # Run the background update and check it did the right thing | ||||
self.run_background_update() | self.run_background_update() | ||||
@@ -193,7 +193,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||||
latest_event_ids = self.get_success( | latest_event_ids = self.get_success( | ||||
self.store.get_latest_event_ids_in_room(self.room_id) | self.store.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
self.assertEqual(latest_event_ids, [event_id_b]) | |||||
self.assertEqual(latest_event_ids, {event_id_b}) | |||||
def test_forked_graph_cleanup(self) -> None: | def test_forked_graph_cleanup(self) -> None: | ||||
r"""Test that extremities are correctly calculated in the presence of | r"""Test that extremities are correctly calculated in the presence of | ||||
@@ -240,7 +240,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||||
latest_event_ids = self.get_success( | latest_event_ids = self.get_success( | ||||
self.store.get_latest_event_ids_in_room(self.room_id) | self.store.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) | |||||
self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c}) | |||||
# Run the background update and check it did the right thing | # Run the background update and check it did the right thing | ||||
self.run_background_update() | self.run_background_update() | ||||
@@ -248,7 +248,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): | |||||
latest_event_ids = self.get_success( | latest_event_ids = self.get_success( | ||||
self.store.get_latest_event_ids_in_room(self.room_id) | self.store.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) | |||||
self.assertEqual(latest_event_ids, {event_id_b, event_id_c}) | |||||
class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | class CleanupExtremDummyEventsTestCase(HomeserverTestCase): | ||||
@@ -51,9 +51,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
self.store = self.hs.get_datastores().main | self.store = self.hs.get_datastores().main | ||||
# Figure out what the most recent event is | # Figure out what the most recent event is | ||||
most_recent = self.get_success( | |||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | |||||
)[0] | |||||
most_recent = next( | |||||
iter( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_latest_event_ids_in_room( | |||||
self.room_id | |||||
) | |||||
) | |||||
) | |||||
) | |||||
join_event = make_event_from_dict( | join_event = make_event_from_dict( | ||||
{ | { | ||||
@@ -100,8 +106,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
# Make sure we actually joined the room | # Make sure we actually joined the room | ||||
self.assertEqual( | self.assertEqual( | ||||
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], | |||||
"$join:test.serv", | |||||
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)), | |||||
{"$join:test.serv"}, | |||||
) | ) | ||||
def test_cant_hide_direct_ancestors(self) -> None: | def test_cant_hide_direct_ancestors(self) -> None: | ||||
@@ -127,9 +133,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
self.http_client.post_json = post_json | self.http_client.post_json = post_json | ||||
# Figure out what the most recent event is | # Figure out what the most recent event is | ||||
most_recent = self.get_success( | |||||
self.store.get_latest_event_ids_in_room(self.room_id) | |||||
)[0] | |||||
most_recent = next( | |||||
iter( | |||||
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) | |||||
) | |||||
) | |||||
# Now lie about an event | # Now lie about an event | ||||
lying_event = make_event_from_dict( | lying_event = make_event_from_dict( | ||||
@@ -165,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
# Make sure the invalid event isn't there | # Make sure the invalid event isn't there | ||||
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) | extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) | ||||
self.assertEqual(extrem[0], "$join:test.serv") | |||||
self.assertEqual(extrem, {"$join:test.serv"}) | |||||
def test_retry_device_list_resync(self) -> None: | def test_retry_device_list_resync(self) -> None: | ||||
"""Tests that device lists are marked as stale if they couldn't be synced, and | """Tests that device lists are marked as stale if they couldn't be synced, and | ||||