|
|
@@ -15,7 +15,6 @@ |
|
|
|
import logging |
|
|
|
from typing import ( |
|
|
|
TYPE_CHECKING, |
|
|
|
Callable, |
|
|
|
Collection, |
|
|
|
Dict, |
|
|
|
FrozenSet, |
|
|
@@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain |
|
|
|
from synapse.util.async_helpers import Linearizer |
|
|
|
from synapse.util.caches import intern_string |
|
|
|
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList |
|
|
|
from synapse.util.cancellation import cancellable |
|
|
|
from synapse.util.iterutils import batch_iter |
|
|
|
from synapse.util.metrics import Measure |
|
|
|
|
|
|
@@ -600,58 +598,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
for room_id, instance, stream_id in txn |
|
|
|
) |
|
|
|
|
|
|
|
@cachedList( |
|
|
|
cached_method_name="get_rooms_for_user_with_stream_ordering", |
|
|
|
list_name="user_ids", |
|
|
|
) |
|
|
|
async def get_rooms_for_users_with_stream_ordering( |
|
|
|
self, user_ids: Collection[str] |
|
|
|
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: |
|
|
|
"""A batched version of `get_rooms_for_user_with_stream_ordering`. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Map from user_id to set of rooms that is currently in. |
|
|
|
""" |
|
|
|
return await self.db_pool.runInteraction( |
|
|
|
"get_rooms_for_users_with_stream_ordering", |
|
|
|
self._get_rooms_for_users_with_stream_ordering_txn, |
|
|
|
user_ids, |
|
|
|
) |
|
|
|
|
|
|
|
def _get_rooms_for_users_with_stream_ordering_txn( |
|
|
|
self, txn: LoggingTransaction, user_ids: Collection[str] |
|
|
|
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: |
|
|
|
|
|
|
|
clause, args = make_in_list_sql_clause( |
|
|
|
self.database_engine, |
|
|
|
"c.state_key", |
|
|
|
user_ids, |
|
|
|
) |
|
|
|
|
|
|
|
sql = f""" |
|
|
|
SELECT c.state_key, room_id, e.instance_name, e.stream_ordering |
|
|
|
FROM current_state_events AS c |
|
|
|
INNER JOIN events AS e USING (room_id, event_id) |
|
|
|
WHERE |
|
|
|
c.type = 'm.room.member' |
|
|
|
AND c.membership = ? |
|
|
|
AND {clause} |
|
|
|
""" |
|
|
|
|
|
|
|
txn.execute(sql, [Membership.JOIN] + args) |
|
|
|
|
|
|
|
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { |
|
|
|
user_id: set() for user_id in user_ids |
|
|
|
} |
|
|
|
for user_id, room_id, instance, stream_id in txn: |
|
|
|
result[user_id].add( |
|
|
|
GetRoomsForUserWithStreamOrdering( |
|
|
|
room_id, PersistedEventPosition(instance, stream_id) |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
return {user_id: frozenset(v) for user_id, v in result.items()} |
|
|
|
|
|
|
|
async def get_users_server_still_shares_room_with( |
|
|
|
self, user_ids: Collection[str] |
|
|
|
) -> Set[str]: |
|
|
@@ -693,19 +639,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): |
|
|
|
|
|
|
|
return {row[0] for row in txn} |
|
|
|
|
|
|
|
@cancellable |
|
|
|
async def get_rooms_for_user( |
|
|
|
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None |
|
|
|
) -> FrozenSet[str]: |
|
|
|
@cached(max_entries=500000, iterable=True) |
|
|
|
async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: |
|
|
|
"""Returns a set of room_ids the user is currently joined to. |
|
|
|
|
|
|
|
If a remote user only returns rooms this server is currently |
|
|
|
participating in. |
|
|
|
""" |
|
|
|
rooms = await self.get_rooms_for_user_with_stream_ordering( |
|
|
|
user_id, on_invalidate=on_invalidate |
|
|
|
rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( |
|
|
|
(user_id,), |
|
|
|
None, |
|
|
|
update_metrics=False, |
|
|
|
) |
|
|
|
if rooms: |
|
|
|
return frozenset(r.room_id for r in rooms) |
|
|
|
|
|
|
|
room_ids = await self.db_pool.simple_select_onecol( |
|
|
|
table="current_state_events", |
|
|
|
keyvalues={ |
|
|
|
"type": EventTypes.Member, |
|
|
|
"membership": Membership.JOIN, |
|
|
|
"state_key": user_id, |
|
|
|
}, |
|
|
|
retcol="room_id", |
|
|
|
desc="get_rooms_for_user", |
|
|
|
) |
|
|
|
return frozenset(r.room_id for r in rooms) |
|
|
|
|
|
|
|
return frozenset(room_ids) |
|
|
|
|
|
|
|
@cachedList( |
|
|
|
cached_method_name="get_rooms_for_user", |
|
|
|
list_name="user_ids", |
|
|
|
) |
|
|
|
async def get_rooms_for_users( |
|
|
|
self, user_ids: Collection[str] |
|
|
|
) -> Dict[str, FrozenSet[str]]: |
|
|
|
"""A batched version of `get_rooms_for_user`. |
|
|
|
|
|
|
|
Returns: |
|
|
|
Map from user_id to set of rooms that is currently in. |
|
|
|
""" |
|
|
|
|
|
|
|
rows = await self.db_pool.simple_select_many_batch( |
|
|
|
table="current_state_events", |
|
|
|
column="state_key", |
|
|
|
iterable=user_ids, |
|
|
|
retcols=( |
|
|
|
"state_key", |
|
|
|
"room_id", |
|
|
|
), |
|
|
|
keyvalues={ |
|
|
|
"type": EventTypes.Member, |
|
|
|
"membership": Membership.JOIN, |
|
|
|
}, |
|
|
|
desc="get_rooms_for_users", |
|
|
|
) |
|
|
|
|
|
|
|
user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} |
|
|
|
|
|
|
|
for row in rows: |
|
|
|
user_rooms[row["state_key"]].add(row["room_id"]) |
|
|
|
|
|
|
|
return {key: frozenset(rooms) for key, rooms in user_rooms.items()} |
|
|
|
|
|
|
|
@cached(max_entries=10000) |
|
|
|
async def does_pair_of_users_share_a_room( |
|
|
|