Bladeren bron

Return read-only collections from `@cached` methods (#13755)

It's important that collections returned from `@cached` methods are not
modified, otherwise future retrievals from the cache will return the
modified collection.

This applies to the return values from `@cached` methods and the values
inside the dictionaries returned by `@cachedList` methods. It's not
necessary for the dictionaries returned by `@cachedList` methods
themselves to be read-only.

Signed-off-by: Sean Quah <seanq@matrix.org>
Co-authored-by: David Robertson <davidr@element.io>
tags/v1.78.0rc1
Sean Quah 1 jaar geleden
committed by GitHub
bovenliggende
commit
d0c713cc85
Geen bekende sleutel gevonden voor deze handtekening in de database GPG sleutel-ID: 4AEE18F83AFDEB23
27 gewijzigde bestanden met toevoegingen van 98 en 77 verwijderingen
  1. +1
    -0
      changelog.d/13755.misc
  2. +2
    -2
      synapse/app/phone_stats_home.py
  3. +3
    -3
      synapse/config/room_directory.py
  4. +3
    -3
      synapse/events/builder.py
  5. +2
    -1
      synapse/federation/federation_server.py
  6. +3
    -3
      synapse/handlers/directory.py
  7. +2
    -2
      synapse/handlers/receipts.py
  8. +1
    -1
      synapse/handlers/room.py
  9. +2
    -2
      synapse/handlers/sync.py
  10. +2
    -2
      synapse/push/bulk_push_rule_evaluator.py
  11. +1
    -1
      synapse/state/__init__.py
  12. +3
    -3
      synapse/storage/controllers/state.py
  13. +1
    -1
      synapse/storage/databases/main/account_data.py
  14. +1
    -1
      synapse/storage/databases/main/appservice.py
  15. +11
    -6
      synapse/storage/databases/main/devices.py
  16. +2
    -2
      synapse/storage/databases/main/directory.py
  17. +15
    -10
      synapse/storage/databases/main/end_to_end_keys.py
  18. +7
    -4
      synapse/storage/databases/main/event_federation.py
  19. +2
    -2
      synapse/storage/databases/main/monthly_active_users.py
  20. +6
    -4
      synapse/storage/databases/main/receipts.py
  21. +2
    -2
      synapse/storage/databases/main/registration.py
  22. +5
    -2
      synapse/storage/databases/main/relations.py
  23. +9
    -10
      synapse/storage/databases/main/roommember.py
  24. +3
    -3
      synapse/storage/databases/main/signatures.py
  25. +5
    -3
      synapse/storage/databases/main/tags.py
  26. +2
    -2
      synapse/storage/databases/main/user_directory.py
  27. +2
    -2
      tests/rest/admin/test_server_notice.py

+ 1
- 0
changelog.d/13755.misc Bestand weergeven

@@ -0,0 +1 @@
Re-type hint some collections as read-only.

+ 2
- 2
synapse/app/phone_stats_home.py Bestand weergeven

@@ -15,7 +15,7 @@ import logging
import math
import resource
import sys
from typing import TYPE_CHECKING, List, Sized, Tuple
from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple

from prometheus_client import Gauge

@@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
@wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None:
current_mau_count = 0
current_mau_count_by_service = {}
current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = ()
store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:


+ 3
- 3
synapse/config/room_directory.py Bestand weergeven

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List
from typing import Any, Collection

from matrix_common.regex import glob_to_regex

@@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config):
return False

def is_publishing_room_allowed(
self, user_id: str, room_id: str, aliases: List[str]
self, user_id: str, room_id: str, aliases: Collection[str]
) -> bool:
"""Checks if the given user is allowed to publish the room

@@ -122,7 +122,7 @@ class _RoomDirectoryRule:
except Exception as e:
raise ConfigError("Failed to parse glob into regex") from e

def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool:
"""Tests if this rule matches the given user_id, room_id and aliases.

Args:


+ 3
- 3
synapse/events/builder.py Bestand weergeven

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union

import attr
from signedjson.types import SigningKey
@@ -103,7 +103,7 @@ class EventBuilder:

async def build(
self,
prev_event_ids: List[str],
prev_event_ids: Collection[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:

format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)


+ 2
- 1
synapse/federation/federation_server.py Bestand weergeven

@@ -23,6 +23,7 @@ from typing import (
Collection,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
@@ -1512,7 +1513,7 @@ class FederationHandlerRegistry:
def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
summary: Dict[str, MemberSummary],
summary: Mapping[str, MemberSummary],
) -> Collection[str]:
"""Calculate state to be returned in a partial_state send_join



+ 3
- 3
synapse/handlers/directory.py Bestand weergeven

@@ -14,7 +14,7 @@

import logging
import string
from typing import TYPE_CHECKING, Iterable, List, Optional
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence

from typing_extensions import Literal

@@ -486,7 +486,7 @@ class DirectoryHandler:
)
if canonical_alias:
# Ensure we do not mutate room_aliases.
room_aliases = room_aliases + [canonical_alias]
room_aliases = list(room_aliases) + [canonical_alias]

if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
@@ -529,7 +529,7 @@ class DirectoryHandler:

async def get_aliases_for_room(
self, requester: Requester, room_id: str
) -> List[str]:
) -> Sequence[str]:
"""
Get a list of the aliases that currently point to this room on this server
"""


+ 2
- 2
synapse/handlers/receipts.py Bestand weergeven

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple

from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
@@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):

@staticmethod
def filter_out_private_receipts(
rooms: List[JsonDict], user_id: str
rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)


+ 1
- 1
synapse/handlers/room.py Bestand weergeven

@@ -1928,6 +1928,6 @@ class RoomShutdownHandler:
return {
"kicked_users": kicked_users,
"failed_to_kick_users": failed_to_kick_users,
"local_aliases": aliases_for_room,
"local_aliases": list(aliases_for_room),
"new_room_id": new_room_id,
}

+ 2
- 2
synapse/handlers/sync.py Bestand weergeven

@@ -1519,7 +1519,7 @@ class SyncHandler:
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = (
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)

@@ -2301,7 +2301,7 @@ class SyncHandler:
sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
tags: Optional[Mapping[str, Mapping[str, Any]]],
account_data: Mapping[str, JsonDict],
always_include: bool = False,
) -> None:


+ 2
- 2
synapse/push/bulk_push_rule_evaluator.py Bestand weergeven

@@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
@@ -149,7 +150,7 @@ class BulkPushRuleEvaluator:
# little, we can skip fetching a huge number of push rules in large rooms.
# This helps make joins and leaves faster.
if event.type == EventTypes.Member:
local_users = []
local_users: Sequence[str] = []
# We never notify a user about their own actions. This is enforced in
# `_action_for_event_by_user` in the loop over `rules_by_user`, but we
# do the same check here to avoid unnecessary DB queries.
@@ -184,7 +185,6 @@ class BulkPushRuleEvaluator:
if event.type == EventTypes.Member and event.membership == Membership.INVITE:
invited = event.state_key
if invited and self.hs.is_mine_id(invited) and invited not in local_users:
local_users = list(local_users)
local_users.append(invited)

if not local_users:


+ 1
- 1
synapse/state/__init__.py Bestand weergeven

@@ -226,7 +226,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)

async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
self, room_id: str, latest_event_ids: Collection[str]
) -> Set[str]:
"""
Get the users IDs who are currently in a room.


+ 3
- 3
synapse/storage/controllers/state.py Bestand weergeven

@@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Awaitable,
Callable,
@@ -23,7 +24,6 @@ from typing import (
List,
Mapping,
Optional,
Set,
Tuple,
)

@@ -527,7 +527,7 @@ class StateStorageController:
)
return state_map.get(key)

async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state.

Blocks until we have full state for the given room. This only happens for rooms
@@ -584,7 +584,7 @@ class StateStorageController:

async def get_users_in_room_with_profiles(
self, room_id: str
) -> Dict[str, ProfileInfo]:
) -> Mapping[str, ProfileInfo]:
"""
Get the current users in the room with their profiles.
If the room is currently partial-stated, this will block until the room has


+ 1
- 1
synapse/storage/databases/main/account_data.py Bestand weergeven

@@ -240,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
) -> Mapping[str, JsonDict]:
"""Get all the client account_data for a user for a room.

Args:


+ 1
- 1
synapse/storage/databases/main/appservice.py Bestand weergeven

@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
) -> List[str]:
) -> Sequence[str]:
"""
Get all users in a room that the appservice controls.



+ 11
- 6
synapse/storage/databases/main/devices.py Bestand weergeven

@@ -21,6 +21,7 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -202,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()

async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
async def count_devices_by_users(
self, user_ids: Optional[Collection[str]] = None
) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.

@@ -213,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""

def count_devices_by_users_txn(
txn: LoggingTransaction, user_ids: List[str]
txn: LoggingTransaction, user_ids: Collection[str]
) -> int:
sql = """
SELECT count(*)
@@ -747,7 +750,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_user_devices_from_cache(
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.

Args:
@@ -775,16 +778,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache

# First fetch all the users which all devices are to be returned.
results: Dict[str, Dict[str, JsonDict]] = {}
results: Dict[str, Mapping[str, JsonDict]] = {}
for user_id in user_ids:
if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already
# fetched all devices for.
device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
device_specific_results.setdefault(user_id, {})[device_id] = device
results.update(device_specific_results)

set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -802,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)

@cached()
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},


+ 2
- 2
synapse/storage/databases/main/directory.py Bestand weergeven

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Sequence, Tuple

import attr

@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
)

@cached(max_entries=5000)
async def get_aliases_for_room(self, room_id: str) -> List[str]:
async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},


+ 15
- 10
synapse/storage/databases/main/end_to_end_keys.py Bestand weergeven

@@ -20,7 +20,9 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
cast,
@@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
) -> List[str]:
) -> Sequence[str]:
"""Returns the fallback key types that have an unused key.

Args:
@@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type)

@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)

# The `Optional` comes from the `@cachedList` decorator.
return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)

def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
@@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.

Args:
@@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

if from_user_id:
result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
result = cast(
Dict[str, Optional[Mapping[str, JsonDict]]],
await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
),
)

return result


+ 7
- 4
synapse/storage/databases/main/event_federation.py Bestand weergeven

@@ -22,6 +22,7 @@ from typing import (
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)

async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
async def get_max_depth_of(
self, event_ids: Collection[str]
) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs

Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)

@cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
) -> List[str]:
) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".

@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".



+ 2
- 2
synapse/storage/databases/main/monthly_active_users.py Bestand weergeven

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)

@cached(num_args=0)
async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table


+ 6
- 4
synapse/storage/databases/main/receipts.py Bestand weergeven

@@ -21,7 +21,9 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
cast,
)
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):

async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
) -> Sequence[JsonDict]:
"""Get receipts for a single room for sending to clients.

Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[JsonDict]:
) -> Sequence[JsonDict]:
"""See get_linearized_receipts_for_room"""

def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
) -> Dict[str, List[JsonDict]]:
) -> Dict[str, Sequence[JsonDict]]:
if not room_ids:
return {}

@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
) -> Mapping[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.



+ 2
- 2
synapse/storage/databases/main/registration.py Bestand weergeven

@@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

import attr

@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)

@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""

def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:


+ 5
- 2
synapse/storage/databases/main/relations.py Bestand weergeven

@@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.

Args:
@@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore):
return result is not None

@cached()
async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
async def get_aggregation_groups_for_event(
self, event_id: str
) -> Sequence[JsonDict]:
raise NotImplementedError()

@cachedList(


+ 9
- 10
synapse/storage/databases/main/roommember.py Bestand weergeven

@@ -24,6 +24,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self._known_servers_count

@cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
async def get_users_in_room(self, room_id: str) -> Sequence[str]:
"""Returns a list of users in the room.

Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)

@cached()
def get_user_in_room_with_profile(
self, room_id: str, user_id: str
) -> Dict[str, ProfileInfo]:
def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
raise NotImplementedError()

@cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles(
self, room_id: str
) -> Dict[str, ProfileInfo]:
) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for all users in a given room.

The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)

@cached(max_entries=100000)
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
) -> List[RoomsForUser]:
) -> Sequence[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.

Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results

@cached(iterable=True)
async def get_local_users_in_room(self, room_id: str) -> List[str]:
async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
"""
Retrieves a list of the current roommembers who are local to the server.
"""
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(user_id)

user_who_share_room = set()
user_who_share_room: Set[str] = set()
for room_id in room_ids:
user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True

@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state."""

# First we check if we already have `get_users_in_room` in the cache, as


+ 3
- 3
synapse/storage/databases/main/signatures.py Bestand weergeven

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Collection, Dict, List, Tuple
from typing import Collection, Dict, List, Mapping, Tuple

from unpaddedbase64 import encode_base64

@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList

class SignatureWorkerStore(EventsWorkerStore):
@cached()
def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
)
async def get_event_reference_hashes(
self, event_ids: Collection[str]
) -> Dict[str, Dict[str, bytes]]:
) -> Mapping[str, Mapping[str, bytes]]:
"""Get all hashes for given events.

Args:


+ 5
- 3
synapse/storage/databases/main/tags.py Bestand weergeven

@@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, Iterable, List, Tuple, cast
from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast

from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)

class TagsWorkerStore(AccountDataWorkerStore):
@cached()
async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
async def get_tags_for_user(
self, user_id: str
) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for a user.


@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):

async def get_updated_tags(
self, user_id: str, stream_id: int
) -> Dict[str, Dict[str, JsonDict]]:
) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version



+ 2
- 2
synapse/storage/databases/main/user_directory.py Bestand weergeven

@@ -16,9 +16,9 @@ import logging
import re
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)

@cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},


+ 2
- 2
tests/rest/admin/test_server_notice.py Bestand weergeven

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import List, Sequence

from twisted.test.proto_helpers import MemoryReactor

@@ -558,7 +558,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):

def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
) -> List[RoomsForUser]:
) -> Sequence[RoomsForUser]:
"""Check invite and room membership status of a user.

Args


Laden…
Annuleren
Opslaan