It's important that collections returned from `@cached` methods are not modified, otherwise future retrievals from the cache will return the modified collection. This applies to the return values from `@cached` methods and the values inside the dictionaries returned by `@cachedList` methods. It's not necessary for the dictionaries returned by `@cachedList` methods themselves to be read-only. Signed-off-by: Sean Quah <seanq@matrix.org> Co-authored-by: David Robertson <davidr@element.io>tags/v1.78.0rc1
@@ -0,0 +1 @@ | |||
Re-type hint some collections as read-only. |
@@ -15,7 +15,7 @@ import logging | |||
import math | |||
import resource | |||
import sys | |||
from typing import TYPE_CHECKING, List, Sized, Tuple | |||
from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple | |||
from prometheus_client import Gauge | |||
@@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None: | |||
@wrap_as_background_process("generate_monthly_active_users") | |||
async def generate_monthly_active_users() -> None: | |||
current_mau_count = 0 | |||
current_mau_count_by_service = {} | |||
current_mau_count_by_service: Mapping[str, int] = {} | |||
reserved_users: Sized = () | |||
store = hs.get_datastores().main | |||
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, List | |||
from typing import Any, Collection | |||
from matrix_common.regex import glob_to_regex | |||
@@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config): | |||
return False | |||
def is_publishing_room_allowed( | |||
self, user_id: str, room_id: str, aliases: List[str] | |||
self, user_id: str, room_id: str, aliases: Collection[str] | |||
) -> bool: | |||
"""Checks if the given user is allowed to publish the room | |||
@@ -122,7 +122,7 @@ class _RoomDirectoryRule: | |||
except Exception as e: | |||
raise ConfigError("Failed to parse glob into regex") from e | |||
def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: | |||
def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool: | |||
"""Tests if this rule matches the given user_id, room_id and aliases. | |||
Args: | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | |||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union | |||
import attr | |||
from signedjson.types import SigningKey | |||
@@ -103,7 +103,7 @@ class EventBuilder: | |||
async def build( | |||
self, | |||
prev_event_ids: List[str], | |||
prev_event_ids: Collection[str], | |||
auth_event_ids: Optional[List[str]], | |||
depth: Optional[int] = None, | |||
) -> EventBase: | |||
@@ -136,7 +136,7 @@ class EventBuilder: | |||
format_version = self.room_version.event_format | |||
# The types of auth/prev events changes between event versions. | |||
prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] | |||
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] | |||
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] | |||
if format_version == EventFormatVersions.ROOM_V1_V2: | |||
auth_events = await self._store.add_event_hashes(auth_event_ids) | |||
@@ -23,6 +23,7 @@ from typing import ( | |||
Collection, | |||
Dict, | |||
List, | |||
Mapping, | |||
Optional, | |||
Tuple, | |||
Union, | |||
@@ -1512,7 +1513,7 @@ class FederationHandlerRegistry: | |||
def _get_event_ids_for_partial_state_join( | |||
join_event: EventBase, | |||
prev_state_ids: StateMap[str], | |||
summary: Dict[str, MemberSummary], | |||
summary: Mapping[str, MemberSummary], | |||
) -> Collection[str]: | |||
"""Calculate state to be returned in a partial_state send_join | |||
@@ -14,7 +14,7 @@ | |||
import logging | |||
import string | |||
from typing import TYPE_CHECKING, Iterable, List, Optional | |||
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence | |||
from typing_extensions import Literal | |||
@@ -486,7 +486,7 @@ class DirectoryHandler: | |||
) | |||
if canonical_alias: | |||
# Ensure we do not mutate room_aliases. | |||
room_aliases = room_aliases + [canonical_alias] | |||
room_aliases = list(room_aliases) + [canonical_alias] | |||
if not self.config.roomdirectory.is_publishing_room_allowed( | |||
user_id, room_id, room_aliases | |||
@@ -529,7 +529,7 @@ class DirectoryHandler: | |||
async def get_aliases_for_room( | |||
self, requester: Requester, room_id: str | |||
) -> List[str]: | |||
) -> Sequence[str]: | |||
""" | |||
Get a list of the aliases that currently point to this room on this server | |||
""" | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple | |||
from synapse.api.constants import EduTypes, ReceiptTypes | |||
from synapse.appservice import ApplicationService | |||
@@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): | |||
@staticmethod | |||
def filter_out_private_receipts( | |||
rooms: List[JsonDict], user_id: str | |||
rooms: Sequence[JsonDict], user_id: str | |||
) -> List[JsonDict]: | |||
""" | |||
Filters a list of serialized receipts (as returned by /sync and /initialSync) | |||
@@ -1928,6 +1928,6 @@ class RoomShutdownHandler: | |||
return { | |||
"kicked_users": kicked_users, | |||
"failed_to_kick_users": failed_to_kick_users, | |||
"local_aliases": aliases_for_room, | |||
"local_aliases": list(aliases_for_room), | |||
"new_room_id": new_room_id, | |||
} |
@@ -1519,7 +1519,7 @@ class SyncHandler: | |||
one_time_keys_count = await self.store.count_e2e_one_time_keys( | |||
user_id, device_id | |||
) | |||
unused_fallback_key_types = ( | |||
unused_fallback_key_types = list( | |||
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) | |||
) | |||
@@ -2301,7 +2301,7 @@ class SyncHandler: | |||
sync_result_builder: "SyncResultBuilder", | |||
room_builder: "RoomSyncResultBuilder", | |||
ephemeral: List[JsonDict], | |||
tags: Optional[Dict[str, Dict[str, Any]]], | |||
tags: Optional[Mapping[str, Mapping[str, Any]]], | |||
account_data: Mapping[str, JsonDict], | |||
always_include: bool = False, | |||
) -> None: | |||
@@ -22,6 +22,7 @@ from typing import ( | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Set, | |||
Tuple, | |||
Union, | |||
@@ -149,7 +150,7 @@ class BulkPushRuleEvaluator: | |||
# little, we can skip fetching a huge number of push rules in large rooms. | |||
# This helps make joins and leaves faster. | |||
if event.type == EventTypes.Member: | |||
local_users = [] | |||
local_users: Sequence[str] = [] | |||
# We never notify a user about their own actions. This is enforced in | |||
# `_action_for_event_by_user` in the loop over `rules_by_user`, but we | |||
# do the same check here to avoid unnecessary DB queries. | |||
@@ -184,7 +185,6 @@ class BulkPushRuleEvaluator: | |||
if event.type == EventTypes.Member and event.membership == Membership.INVITE: | |||
invited = event.state_key | |||
if invited and self.hs.is_mine_id(invited) and invited not in local_users: | |||
local_users = list(local_users) | |||
local_users.append(invited) | |||
if not local_users: | |||
@@ -226,7 +226,7 @@ class StateHandler: | |||
return await ret.get_state(self._state_storage_controller, state_filter) | |||
async def get_current_user_ids_in_room( | |||
self, room_id: str, latest_event_ids: List[str] | |||
self, room_id: str, latest_event_ids: Collection[str] | |||
) -> Set[str]: | |||
""" | |||
Get the users IDs who are currently in a room. | |||
@@ -14,6 +14,7 @@ | |||
import logging | |||
from typing import ( | |||
TYPE_CHECKING, | |||
AbstractSet, | |||
Any, | |||
Awaitable, | |||
Callable, | |||
@@ -23,7 +24,6 @@ from typing import ( | |||
List, | |||
Mapping, | |||
Optional, | |||
Set, | |||
Tuple, | |||
) | |||
@@ -527,7 +527,7 @@ class StateStorageController: | |||
) | |||
return state_map.get(key) | |||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: | |||
async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: | |||
"""Get current hosts in room based on current state. | |||
Blocks until we have full state for the given room. This only happens for rooms | |||
@@ -584,7 +584,7 @@ class StateStorageController: | |||
async def get_users_in_room_with_profiles( | |||
self, room_id: str | |||
) -> Dict[str, ProfileInfo]: | |||
) -> Mapping[str, ProfileInfo]: | |||
""" | |||
Get the current users in the room with their profiles. | |||
If the room is currently partial-stated, this will block until the room has | |||
@@ -240,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
@cached(num_args=2, tree=True) | |||
async def get_account_data_for_room( | |||
self, user_id: str, room_id: str | |||
) -> Dict[str, JsonDict]: | |||
) -> Mapping[str, JsonDict]: | |||
"""Get all the client account_data for a user for a room. | |||
Args: | |||
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): | |||
room_id: str, | |||
app_service: "ApplicationService", | |||
cache_context: _CacheContext, | |||
) -> List[str]: | |||
) -> Sequence[str]: | |||
""" | |||
Get all users in a room that the appservice controls. | |||
@@ -21,6 +21,7 @@ from typing import ( | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Set, | |||
Tuple, | |||
@@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
def get_device_stream_token(self) -> int: | |||
return self._device_list_id_gen.get_current_token() | |||
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: | |||
async def count_devices_by_users( | |||
self, user_ids: Optional[Collection[str]] = None | |||
) -> int: | |||
"""Retrieve number of all devices of given users. | |||
Only returns number of devices that are not marked as hidden. | |||
@@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
""" | |||
def count_devices_by_users_txn( | |||
txn: LoggingTransaction, user_ids: List[str] | |||
txn: LoggingTransaction, user_ids: Collection[str] | |||
) -> int: | |||
sql = """ | |||
SELECT count(*) | |||
@@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
@cancellable | |||
async def get_user_devices_from_cache( | |||
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] | |||
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: | |||
) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: | |||
"""Get the devices (and keys if any) for remote users from the cache. | |||
Args: | |||
@@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache | |||
# First fetch all the users which all devices are to be returned. | |||
results: Dict[str, Dict[str, JsonDict]] = {} | |||
results: Dict[str, Mapping[str, JsonDict]] = {} | |||
for user_id in user_ids: | |||
if user_id in user_ids_in_cache: | |||
results[user_id] = await self.get_cached_devices_for_user(user_id) | |||
# Then fetch all device-specific requests, but skip users we've already | |||
# fetched all devices for. | |||
device_specific_results: Dict[str, Dict[str, JsonDict]] = {} | |||
for user_id, device_id in user_and_device_ids: | |||
if user_id in user_ids_in_cache and user_id not in user_ids: | |||
device = await self._get_cached_user_device(user_id, device_id) | |||
results.setdefault(user_id, {})[device_id] = device | |||
device_specific_results.setdefault(user_id, {})[device_id] = device | |||
results.update(device_specific_results) | |||
set_tag("in_cache", str(results)) | |||
set_tag("not_in_cache", str(user_ids_not_in_cache)) | |||
@@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||
return db_to_json(content) | |||
@cached() | |||
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: | |||
async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: | |||
devices = await self.db_pool.simple_select_list( | |||
table="device_lists_remote_cache", | |||
keyvalues={"user_id": user_id}, | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Iterable, List, Optional, Tuple | |||
from typing import Iterable, List, Optional, Sequence, Tuple | |||
import attr | |||
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): | |||
) | |||
@cached(max_entries=5000) | |||
async def get_aliases_for_room(self, room_id: str) -> List[str]: | |||
async def get_aliases_for_room(self, room_id: str) -> Sequence[str]: | |||
return await self.db_pool.simple_select_onecol( | |||
"room_aliases", | |||
{"room_id": room_id}, | |||
@@ -20,7 +20,9 @@ from typing import ( | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Tuple, | |||
Union, | |||
cast, | |||
@@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
@cached(max_entries=10000) | |||
async def get_e2e_unused_fallback_key_types( | |||
self, user_id: str, device_id: str | |||
) -> List[str]: | |||
) -> Sequence[str]: | |||
"""Returns the fallback key types that have an unused key. | |||
Args: | |||
@@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
return user_keys.get(key_type) | |||
@cached(num_args=1) | |||
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: | |||
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: | |||
"""Dummy function. Only used to make a cache for | |||
_get_bare_e2e_cross_signing_keys_bulk. | |||
""" | |||
@@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
) | |||
async def _get_bare_e2e_cross_signing_keys_bulk( | |||
self, user_ids: Iterable[str] | |||
) -> Dict[str, Optional[Dict[str, JsonDict]]]: | |||
) -> Dict[str, Optional[Mapping[str, JsonDict]]]: | |||
"""Returns the cross-signing keys for a set of users. The output of this | |||
function should be passed to _get_e2e_cross_signing_signatures_txn if | |||
the signatures for the calling user need to be fetched. | |||
@@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
) | |||
# The `Optional` comes from the `@cachedList` decorator. | |||
return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) | |||
return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) | |||
def _get_bare_e2e_cross_signing_keys_bulk_txn( | |||
self, | |||
@@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
@cancellable | |||
async def get_e2e_cross_signing_keys_bulk( | |||
self, user_ids: List[str], from_user_id: Optional[str] = None | |||
) -> Dict[str, Optional[Dict[str, JsonDict]]]: | |||
) -> Dict[str, Optional[Mapping[str, JsonDict]]]: | |||
"""Returns the cross-signing keys for a set of users. | |||
Args: | |||
@@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) | |||
if from_user_id: | |||
result = await self.db_pool.runInteraction( | |||
"get_e2e_cross_signing_signatures", | |||
self._get_e2e_cross_signing_signatures_txn, | |||
result, | |||
from_user_id, | |||
result = cast( | |||
Dict[str, Optional[Mapping[str, JsonDict]]], | |||
await self.db_pool.runInteraction( | |||
"get_e2e_cross_signing_signatures", | |||
self._get_e2e_cross_signing_signatures_txn, | |||
result, | |||
from_user_id, | |||
), | |||
) | |||
return result | |||
@@ -22,6 +22,7 @@ from typing import ( | |||
Iterable, | |||
List, | |||
Optional, | |||
Sequence, | |||
Set, | |||
Tuple, | |||
cast, | |||
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
room_id, | |||
) | |||
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: | |||
async def get_max_depth_of( | |||
self, event_ids: Collection[str] | |||
) -> Tuple[Optional[str], int]: | |||
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs | |||
Args: | |||
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
) | |||
@cached(max_entries=5000, iterable=True) | |||
async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: | |||
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: | |||
return await self.db_pool.simple_select_onecol( | |||
table="event_forward_extremities", | |||
keyvalues={"room_id": room_id}, | |||
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
@cancellable | |||
async def get_forward_extremities_for_room_at_stream_ordering( | |||
self, room_id: str, stream_ordering: int | |||
) -> List[str]: | |||
) -> Sequence[str]: | |||
"""For a given room_id and stream_ordering, return the forward | |||
extremeties of the room at that point in "time". | |||
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
@cached(max_entries=5000, num_args=2) | |||
async def _get_forward_extremeties_for_room( | |||
self, room_id: str, stream_ordering: int | |||
) -> List[str]: | |||
) -> Sequence[str]: | |||
"""For a given room_id and stream_ordering, return the forward | |||
extremeties of the room at that point in "time". | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast | |||
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
from synapse.storage.database import ( | |||
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): | |||
return await self.db_pool.runInteraction("count_users", _count_users) | |||
@cached(num_args=0) | |||
async def get_monthly_active_count_by_service(self) -> Dict[str, int]: | |||
async def get_monthly_active_count_by_service(self) -> Mapping[str, int]: | |||
"""Generates current count of monthly active users broken down by service. | |||
A service is typically an appservice but also includes native matrix users. | |||
Since the `monthly_active_users` table is populated from the `user_ips` table | |||
@@ -21,7 +21,9 @@ from typing import ( | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Tuple, | |||
cast, | |||
) | |||
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
async def get_linearized_receipts_for_room( | |||
self, room_id: str, to_key: int, from_key: Optional[int] = None | |||
) -> List[dict]: | |||
) -> Sequence[JsonDict]: | |||
"""Get receipts for a single room for sending to clients. | |||
Args: | |||
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
@cached(tree=True) | |||
async def _get_linearized_receipts_for_room( | |||
self, room_id: str, to_key: int, from_key: Optional[int] = None | |||
) -> List[JsonDict]: | |||
) -> Sequence[JsonDict]: | |||
"""See get_linearized_receipts_for_room""" | |||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: | |||
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
) | |||
async def _get_linearized_receipts_for_rooms( | |||
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None | |||
) -> Dict[str, List[JsonDict]]: | |||
) -> Dict[str, Sequence[JsonDict]]: | |||
if not room_ids: | |||
return {} | |||
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
) | |||
async def get_linearized_receipts_for_all_rooms( | |||
self, to_key: int, from_key: Optional[int] = None | |||
) -> Dict[str, JsonDict]: | |||
) -> Mapping[str, JsonDict]: | |||
"""Get receipts for all rooms between two stream_ids, up | |||
to a limit of the latest 100 read receipts. | |||
@@ -16,7 +16,7 @@ | |||
import logging | |||
import random | |||
import re | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast | |||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast | |||
import attr | |||
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): | |||
) | |||
@cached() | |||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: | |||
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: | |||
"""Deprecated: use get_userinfo_by_id instead""" | |||
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: | |||
@@ -22,6 +22,7 @@ from typing import ( | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Set, | |||
Tuple, | |||
Union, | |||
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
direction: Direction = Direction.BACKWARDS, | |||
from_token: Optional[StreamToken] = None, | |||
to_token: Optional[StreamToken] = None, | |||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: | |||
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: | |||
"""Get a list of relations for an event, ordered by topological ordering. | |||
Args: | |||
@@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore): | |||
return result is not None | |||
@cached() | |||
async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: | |||
async def get_aggregation_groups_for_event( | |||
self, event_id: str | |||
) -> Sequence[JsonDict]: | |||
raise NotImplementedError() | |||
@cachedList( | |||
@@ -24,6 +24,7 @@ from typing import ( | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Set, | |||
Tuple, | |||
Union, | |||
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
return self._known_servers_count | |||
@cached(max_entries=100000, iterable=True) | |||
async def get_users_in_room(self, room_id: str) -> List[str]: | |||
async def get_users_in_room(self, room_id: str) -> Sequence[str]: | |||
"""Returns a list of users in the room. | |||
Will return inaccurate results for rooms with partial state, since the state for | |||
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
) | |||
@cached() | |||
def get_user_in_room_with_profile( | |||
self, room_id: str, user_id: str | |||
) -> Dict[str, ProfileInfo]: | |||
def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo: | |||
raise NotImplementedError() | |||
@cachedList( | |||
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
@cached(max_entries=100000, iterable=True) | |||
async def get_users_in_room_with_profiles( | |||
self, room_id: str | |||
) -> Dict[str, ProfileInfo]: | |||
) -> Mapping[str, ProfileInfo]: | |||
"""Get a mapping from user ID to profile information for all users in a given room. | |||
The profile information comes directly from this room's `m.room.member` | |||
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
) | |||
@cached(max_entries=100000) | |||
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: | |||
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: | |||
"""Get the details of a room roughly suitable for use by the room | |||
summary extension to /sync. Useful when lazy loading room members. | |||
Args: | |||
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
@cached() | |||
async def get_invited_rooms_for_local_user( | |||
self, user_id: str | |||
) -> List[RoomsForUser]: | |||
) -> Sequence[RoomsForUser]: | |||
"""Get all the rooms the *local* user is invited to. | |||
Args: | |||
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
return results | |||
@cached(iterable=True) | |||
async def get_local_users_in_room(self, room_id: str) -> List[str]: | |||
async def get_local_users_in_room(self, room_id: str) -> Sequence[str]: | |||
""" | |||
Retrieves a list of the current roommembers who are local to the server. | |||
""" | |||
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
"""Returns the set of users who share a room with `user_id`""" | |||
room_ids = await self.get_rooms_for_user(user_id) | |||
user_who_share_room = set() | |||
user_who_share_room: Set[str] = set() | |||
for room_id in room_ids: | |||
user_ids = await self.get_users_in_room(room_id) | |||
user_who_share_room.update(user_ids) | |||
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
return True | |||
@cached(iterable=True, max_entries=10000) | |||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: | |||
async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: | |||
"""Get current hosts in room based on current state.""" | |||
# First we check if we already have `get_users_in_room` in the cache, as | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Collection, Dict, List, Tuple | |||
from typing import Collection, Dict, List, Mapping, Tuple | |||
from unpaddedbase64 import encode_base64 | |||
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList | |||
class SignatureWorkerStore(EventsWorkerStore): | |||
@cached() | |||
def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]: | |||
def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]: | |||
# This is a dummy function to allow get_event_reference_hashes | |||
# to use its cache | |||
raise NotImplementedError() | |||
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore): | |||
) | |||
async def get_event_reference_hashes( | |||
self, event_ids: Collection[str] | |||
) -> Dict[str, Dict[str, bytes]]: | |||
) -> Mapping[str, Mapping[str, bytes]]: | |||
"""Get all hashes for given events. | |||
Args: | |||
@@ -15,7 +15,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Any, Dict, Iterable, List, Tuple, cast | |||
from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast | |||
from synapse.api.constants import AccountDataTypes | |||
from synapse.replication.tcp.streams import AccountDataStream | |||
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__) | |||
class TagsWorkerStore(AccountDataWorkerStore): | |||
@cached() | |||
async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: | |||
async def get_tags_for_user( | |||
self, user_id: str | |||
) -> Mapping[str, Mapping[str, JsonDict]]: | |||
"""Get all the tags for a user. | |||
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore): | |||
async def get_updated_tags( | |||
self, user_id: str, stream_id: int | |||
) -> Dict[str, Dict[str, JsonDict]]: | |||
) -> Mapping[str, Mapping[str, JsonDict]]: | |||
"""Get all the tags for the rooms where the tags have changed since the | |||
given version | |||
@@ -16,9 +16,9 @@ import logging | |||
import re | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Set, | |||
@@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): | |||
) | |||
@cached() | |||
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: | |||
async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: | |||
return await self.db_pool.simple_select_one( | |||
table="user_directory", | |||
keyvalues={"user_id": user_id}, | |||
@@ -11,7 +11,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import List | |||
from typing import List, Sequence | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): | |||
def _check_invite_and_join_status( | |||
self, user_id: str, expected_invites: int, expected_memberships: int | |||
) -> List[RoomsForUser]: | |||
) -> Sequence[RoomsForUser]: | |||
"""Check invite and room membership status of a user. | |||
Args | |||