@@ -0,0 +1 @@ | |||||
Improve type hints. |
@@ -23,7 +23,7 @@ from netaddr import IPSet | |||||
from synapse.api.constants import EventTypes | from synapse.api.constants import EventTypes | ||||
from synapse.events import EventBase | from synapse.events import EventBase | ||||
from synapse.types import DeviceListUpdates, JsonDict, UserID | |||||
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID | |||||
from synapse.util.caches.descriptors import _CacheContext, cached | from synapse.util.caches.descriptors import _CacheContext, cached | ||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
@@ -379,8 +379,8 @@ class AppServiceTransaction: | |||||
service: ApplicationService, | service: ApplicationService, | ||||
id: int, | id: int, | ||||
events: Sequence[EventBase], | events: Sequence[EventBase], | ||||
ephemeral: List[JsonDict], | |||||
to_device_messages: List[JsonDict], | |||||
ephemeral: List[JsonMapping], | |||||
to_device_messages: List[JsonMapping], | |||||
one_time_keys_count: TransactionOneTimeKeysCount, | one_time_keys_count: TransactionOneTimeKeysCount, | ||||
unused_fallback_keys: TransactionUnusedFallbackKeys, | unused_fallback_keys: TransactionUnusedFallbackKeys, | ||||
device_list_summary: DeviceListUpdates, | device_list_summary: DeviceListUpdates, | ||||
@@ -41,7 +41,7 @@ from synapse.events import EventBase | |||||
from synapse.events.utils import SerializeEventConfig, serialize_event | from synapse.events.utils import SerializeEventConfig, serialize_event | ||||
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint | from synapse.http.client import SimpleHttpClient, is_unknown_endpoint | ||||
from synapse.logging import opentracing | from synapse.logging import opentracing | ||||
from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID | |||||
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID | |||||
from synapse.util.caches.response_cache import ResponseCache | from synapse.util.caches.response_cache import ResponseCache | ||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
@@ -306,8 +306,8 @@ class ApplicationServiceApi(SimpleHttpClient): | |||||
self, | self, | ||||
service: "ApplicationService", | service: "ApplicationService", | ||||
events: Sequence[EventBase], | events: Sequence[EventBase], | ||||
ephemeral: List[JsonDict], | |||||
to_device_messages: List[JsonDict], | |||||
ephemeral: List[JsonMapping], | |||||
to_device_messages: List[JsonMapping], | |||||
one_time_keys_count: TransactionOneTimeKeysCount, | one_time_keys_count: TransactionOneTimeKeysCount, | ||||
unused_fallback_keys: TransactionUnusedFallbackKeys, | unused_fallback_keys: TransactionUnusedFallbackKeys, | ||||
device_list_summary: DeviceListUpdates, | device_list_summary: DeviceListUpdates, | ||||
@@ -73,7 +73,7 @@ from synapse.events import EventBase | |||||
from synapse.logging.context import run_in_background | from synapse.logging.context import run_in_background | ||||
from synapse.metrics.background_process_metrics import run_as_background_process | from synapse.metrics.background_process_metrics import run_as_background_process | ||||
from synapse.storage.databases.main import DataStore | from synapse.storage.databases.main import DataStore | ||||
from synapse.types import DeviceListUpdates, JsonDict | |||||
from synapse.types import DeviceListUpdates, JsonMapping | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
@@ -121,8 +121,8 @@ class ApplicationServiceScheduler: | |||||
self, | self, | ||||
appservice: ApplicationService, | appservice: ApplicationService, | ||||
events: Optional[Collection[EventBase]] = None, | events: Optional[Collection[EventBase]] = None, | ||||
ephemeral: Optional[Collection[JsonDict]] = None, | |||||
to_device_messages: Optional[Collection[JsonDict]] = None, | |||||
ephemeral: Optional[Collection[JsonMapping]] = None, | |||||
to_device_messages: Optional[Collection[JsonMapping]] = None, | |||||
device_list_summary: Optional[DeviceListUpdates] = None, | device_list_summary: Optional[DeviceListUpdates] = None, | ||||
) -> None: | ) -> None: | ||||
""" | """ | ||||
@@ -180,9 +180,9 @@ class _ServiceQueuer: | |||||
# dict of {service_id: [events]} | # dict of {service_id: [events]} | ||||
self.queued_events: Dict[str, List[EventBase]] = {} | self.queued_events: Dict[str, List[EventBase]] = {} | ||||
# dict of {service_id: [events]} | # dict of {service_id: [events]} | ||||
self.queued_ephemeral: Dict[str, List[JsonDict]] = {} | |||||
self.queued_ephemeral: Dict[str, List[JsonMapping]] = {} | |||||
# dict of {service_id: [to_device_message_json]} | # dict of {service_id: [to_device_message_json]} | ||||
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} | |||||
self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {} | |||||
# dict of {service_id: [device_list_summary]} | # dict of {service_id: [device_list_summary]} | ||||
self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} | self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} | ||||
@@ -293,8 +293,8 @@ class _ServiceQueuer: | |||||
self, | self, | ||||
service: ApplicationService, | service: ApplicationService, | ||||
events: Iterable[EventBase], | events: Iterable[EventBase], | ||||
ephemerals: Iterable[JsonDict], | |||||
to_device_messages: Iterable[JsonDict], | |||||
ephemerals: Iterable[JsonMapping], | |||||
to_device_messages: Iterable[JsonMapping], | |||||
) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: | ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: | ||||
""" | """ | ||||
Given a list of the events, ephemeral messages and to-device messages, | Given a list of the events, ephemeral messages and to-device messages, | ||||
@@ -364,8 +364,8 @@ class _TransactionController: | |||||
self, | self, | ||||
service: ApplicationService, | service: ApplicationService, | ||||
events: Sequence[EventBase], | events: Sequence[EventBase], | ||||
ephemeral: Optional[List[JsonDict]] = None, | |||||
to_device_messages: Optional[List[JsonDict]] = None, | |||||
ephemeral: Optional[List[JsonMapping]] = None, | |||||
to_device_messages: Optional[List[JsonMapping]] = None, | |||||
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, | one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, | ||||
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, | unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, | ||||
device_list_summary: Optional[DeviceListUpdates] = None, | device_list_summary: Optional[DeviceListUpdates] = None, | ||||
@@ -46,6 +46,7 @@ from synapse.storage.databases.main.directory import RoomAliasMapping | |||||
from synapse.types import ( | from synapse.types import ( | ||||
DeviceListUpdates, | DeviceListUpdates, | ||||
JsonDict, | JsonDict, | ||||
JsonMapping, | |||||
RoomAlias, | RoomAlias, | ||||
RoomStreamToken, | RoomStreamToken, | ||||
StreamKeyType, | StreamKeyType, | ||||
@@ -397,7 +398,7 @@ class ApplicationServicesHandler: | |||||
async def _handle_typing( | async def _handle_typing( | ||||
self, service: ApplicationService, new_token: int | self, service: ApplicationService, new_token: int | ||||
) -> List[JsonDict]: | |||||
) -> List[JsonMapping]: | |||||
""" | """ | ||||
Return the typing events since the given stream token that the given application | Return the typing events since the given stream token that the given application | ||||
service should receive. | service should receive. | ||||
@@ -432,7 +433,7 @@ class ApplicationServicesHandler: | |||||
async def _handle_receipts( | async def _handle_receipts( | ||||
self, service: ApplicationService, new_token: int | self, service: ApplicationService, new_token: int | ||||
) -> List[JsonDict]: | |||||
) -> List[JsonMapping]: | |||||
""" | """ | ||||
Return the latest read receipts that the given application service should receive. | Return the latest read receipts that the given application service should receive. | ||||
@@ -471,7 +472,7 @@ class ApplicationServicesHandler: | |||||
service: ApplicationService, | service: ApplicationService, | ||||
users: Collection[Union[str, UserID]], | users: Collection[Union[str, UserID]], | ||||
new_token: Optional[int], | new_token: Optional[int], | ||||
) -> List[JsonDict]: | |||||
) -> List[JsonMapping]: | |||||
""" | """ | ||||
Return the latest presence updates that the given application service should receive. | Return the latest presence updates that the given application service should receive. | ||||
@@ -491,7 +492,7 @@ class ApplicationServicesHandler: | |||||
A list of json dictionaries containing data derived from the presence events | A list of json dictionaries containing data derived from the presence events | ||||
that should be sent to the given application service. | that should be sent to the given application service. | ||||
""" | """ | ||||
events: List[JsonDict] = [] | |||||
events: List[JsonMapping] = [] | |||||
presence_source = self.event_sources.sources.presence | presence_source = self.event_sources.sources.presence | ||||
from_key = await self.store.get_type_stream_id_for_appservice( | from_key = await self.store.get_type_stream_id_for_appservice( | ||||
service, "presence" | service, "presence" | ||||
@@ -14,7 +14,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
import logging | import logging | ||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple | |||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple | |||||
import attr | import attr | ||||
from canonicaljson import encode_canonical_json | from canonicaljson import encode_canonical_json | ||||
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background | |||||
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace | from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace | ||||
from synapse.types import ( | from synapse.types import ( | ||||
JsonDict, | JsonDict, | ||||
JsonMapping, | |||||
UserID, | UserID, | ||||
get_domain_from_id, | get_domain_from_id, | ||||
get_verify_key_from_cross_signing_key, | get_verify_key_from_cross_signing_key, | ||||
@@ -272,11 +273,7 @@ class E2eKeysHandler: | |||||
delay_cancellation=True, | delay_cancellation=True, | ||||
) | ) | ||||
ret = {"device_keys": results, "failures": failures} | |||||
ret.update(cross_signing_keys) | |||||
return ret | |||||
return {"device_keys": results, "failures": failures, **cross_signing_keys} | |||||
@trace | @trace | ||||
async def _query_devices_for_destination( | async def _query_devices_for_destination( | ||||
@@ -408,7 +405,7 @@ class E2eKeysHandler: | |||||
@cancellable | @cancellable | ||||
async def get_cross_signing_keys_from_cache( | async def get_cross_signing_keys_from_cache( | ||||
self, query: Iterable[str], from_user_id: Optional[str] | self, query: Iterable[str], from_user_id: Optional[str] | ||||
) -> Dict[str, Dict[str, dict]]: | |||||
) -> Dict[str, Dict[str, JsonMapping]]: | |||||
"""Get cross-signing keys for users from the database | """Get cross-signing keys for users from the database | ||||
Args: | Args: | ||||
@@ -551,16 +548,13 @@ class E2eKeysHandler: | |||||
self.config.federation.allow_device_name_lookup_over_federation | self.config.federation.allow_device_name_lookup_over_federation | ||||
), | ), | ||||
) | ) | ||||
ret = {"device_keys": res} | |||||
# add in the cross-signing keys | # add in the cross-signing keys | ||||
cross_signing_keys = await self.get_cross_signing_keys_from_cache( | cross_signing_keys = await self.get_cross_signing_keys_from_cache( | ||||
device_keys_query, None | device_keys_query, None | ||||
) | ) | ||||
ret.update(cross_signing_keys) | |||||
return ret | |||||
return {"device_keys": res, **cross_signing_keys} | |||||
async def claim_local_one_time_keys( | async def claim_local_one_time_keys( | ||||
self, | self, | ||||
@@ -1127,7 +1121,7 @@ class E2eKeysHandler: | |||||
user_id: str, | user_id: str, | ||||
master_key_id: str, | master_key_id: str, | ||||
signed_master_key: JsonDict, | signed_master_key: JsonDict, | ||||
stored_master_key: JsonDict, | |||||
stored_master_key: JsonMapping, | |||||
devices: Dict[str, Dict[str, JsonDict]], | devices: Dict[str, Dict[str, JsonDict]], | ||||
) -> List["SignatureListItem"]: | ) -> List["SignatureListItem"]: | ||||
"""Check signatures of a user's master key made by their devices. | """Check signatures of a user's master key made by their devices. | ||||
@@ -1278,7 +1272,7 @@ class E2eKeysHandler: | |||||
async def _get_e2e_cross_signing_verify_key( | async def _get_e2e_cross_signing_verify_key( | ||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None | self, user_id: str, key_type: str, from_user_id: Optional[str] = None | ||||
) -> Tuple[JsonDict, str, VerifyKey]: | |||||
) -> Tuple[JsonMapping, str, VerifyKey]: | |||||
"""Fetch locally or remotely query for a cross-signing public key. | """Fetch locally or remotely query for a cross-signing public key. | ||||
First, attempt to fetch the cross-signing public key from storage. | First, attempt to fetch the cross-signing public key from storage. | ||||
@@ -1333,7 +1327,7 @@ class E2eKeysHandler: | |||||
self, | self, | ||||
user: UserID, | user: UserID, | ||||
desired_key_type: str, | desired_key_type: str, | ||||
) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]: | |||||
) -> Optional[Tuple[JsonMapping, str, VerifyKey]]: | |||||
"""Queries cross-signing keys for a remote user and saves them to the database | """Queries cross-signing keys for a remote user and saves them to the database | ||||
Only the key specified by `key_type` will be returned, while all retrieved keys | Only the key specified by `key_type` will be returned, while all retrieved keys | ||||
@@ -1474,7 +1468,7 @@ def _check_device_signature( | |||||
user_id: str, | user_id: str, | ||||
verify_key: VerifyKey, | verify_key: VerifyKey, | ||||
signed_device: JsonDict, | signed_device: JsonDict, | ||||
stored_device: JsonDict, | |||||
stored_device: JsonMapping, | |||||
) -> None: | ) -> None: | ||||
"""Check that a signature on a device or cross-signing key is correct and | """Check that a signature on a device or cross-signing key is correct and | ||||
matches the copy of the device/key that we have stored. Throws an | matches the copy of the device/key that we have stored. Throws an | ||||
@@ -32,6 +32,7 @@ from synapse.storage.roommember import RoomsForUser | |||||
from synapse.streams.config import PaginationConfig | from synapse.streams.config import PaginationConfig | ||||
from synapse.types import ( | from synapse.types import ( | ||||
JsonDict, | JsonDict, | ||||
JsonMapping, | |||||
Requester, | Requester, | ||||
RoomStreamToken, | RoomStreamToken, | ||||
StreamKeyType, | StreamKeyType, | ||||
@@ -454,7 +455,7 @@ class InitialSyncHandler: | |||||
for s in states | for s in states | ||||
] | ] | ||||
async def get_receipts() -> List[JsonDict]: | |||||
async def get_receipts() -> List[JsonMapping]: | |||||
receipts = await self.store.get_linearized_receipts_for_room( | receipts = await self.store.get_linearized_receipts_for_room( | ||||
room_id, to_key=now_token.receipt_key | room_id, to_key=now_token.receipt_key | ||||
) | ) | ||||
@@ -19,6 +19,7 @@ from synapse.appservice import ApplicationService | |||||
from synapse.streams import EventSource | from synapse.streams import EventSource | ||||
from synapse.types import ( | from synapse.types import ( | ||||
JsonDict, | JsonDict, | ||||
JsonMapping, | |||||
ReadReceipt, | ReadReceipt, | ||||
StreamKeyType, | StreamKeyType, | ||||
UserID, | UserID, | ||||
@@ -204,15 +205,15 @@ class ReceiptsHandler: | |||||
await self.federation_sender.send_read_receipt(receipt) | await self.federation_sender.send_read_receipt(receipt) | ||||
class ReceiptEventSource(EventSource[int, JsonDict]): | |||||
class ReceiptEventSource(EventSource[int, JsonMapping]): | |||||
def __init__(self, hs: "HomeServer"): | def __init__(self, hs: "HomeServer"): | ||||
self.store = hs.get_datastores().main | self.store = hs.get_datastores().main | ||||
self.config = hs.config | self.config = hs.config | ||||
@staticmethod | @staticmethod | ||||
def filter_out_private_receipts( | def filter_out_private_receipts( | ||||
rooms: Sequence[JsonDict], user_id: str | |||||
) -> List[JsonDict]: | |||||
rooms: Sequence[JsonMapping], user_id: str | |||||
) -> List[JsonMapping]: | |||||
""" | """ | ||||
Filters a list of serialized receipts (as returned by /sync and /initialSync) | Filters a list of serialized receipts (as returned by /sync and /initialSync) | ||||
and removes private read receipts of other users. | and removes private read receipts of other users. | ||||
@@ -229,7 +230,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): | |||||
The same as rooms, but filtered. | The same as rooms, but filtered. | ||||
""" | """ | ||||
result = [] | |||||
result: List[JsonMapping] = [] | |||||
# Iterate through each room's receipt content. | # Iterate through each room's receipt content. | ||||
for room in rooms: | for room in rooms: | ||||
@@ -282,7 +283,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): | |||||
room_ids: Iterable[str], | room_ids: Iterable[str], | ||||
is_guest: bool, | is_guest: bool, | ||||
explicit_room_id: Optional[str] = None, | explicit_room_id: Optional[str] = None, | ||||
) -> Tuple[List[JsonDict], int]: | |||||
) -> Tuple[List[JsonMapping], int]: | |||||
from_key = int(from_key) | from_key = int(from_key) | ||||
to_key = self.get_current_key() | to_key = self.get_current_key() | ||||
@@ -301,7 +302,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): | |||||
async def get_new_events_as( | async def get_new_events_as( | ||||
self, from_key: int, to_key: int, service: ApplicationService | self, from_key: int, to_key: int, service: ApplicationService | ||||
) -> Tuple[List[JsonDict], int]: | |||||
) -> Tuple[List[JsonMapping], int]: | |||||
"""Returns a set of new read receipt events that an appservice | """Returns a set of new read receipt events that an appservice | ||||
may be interested in. | may be interested in. | ||||
@@ -235,7 +235,7 @@ class SyncResult: | |||||
archived: List[ArchivedSyncResult] | archived: List[ArchivedSyncResult] | ||||
to_device: List[JsonDict] | to_device: List[JsonDict] | ||||
device_lists: DeviceListUpdates | device_lists: DeviceListUpdates | ||||
device_one_time_keys_count: JsonDict | |||||
device_one_time_keys_count: JsonMapping | |||||
device_unused_fallback_key_types: List[str] | device_unused_fallback_key_types: List[str] | ||||
def __bool__(self) -> bool: | def __bool__(self) -> bool: | ||||
@@ -1558,7 +1558,7 @@ class SyncHandler: | |||||
logger.debug("Fetching OTK data") | logger.debug("Fetching OTK data") | ||||
device_id = sync_config.device_id | device_id = sync_config.device_id | ||||
one_time_keys_count: JsonDict = {} | |||||
one_time_keys_count: JsonMapping = {} | |||||
unused_fallback_key_types: List[str] = [] | unused_fallback_key_types: List[str] = [] | ||||
if device_id: | if device_id: | ||||
# TODO: We should have a way to let clients differentiate between the states of: | # TODO: We should have a way to let clients differentiate between the states of: | ||||
@@ -26,7 +26,14 @@ from synapse.metrics.background_process_metrics import ( | |||||
) | ) | ||||
from synapse.replication.tcp.streams import TypingStream | from synapse.replication.tcp.streams import TypingStream | ||||
from synapse.streams import EventSource | from synapse.streams import EventSource | ||||
from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID | |||||
from synapse.types import ( | |||||
JsonDict, | |||||
JsonMapping, | |||||
Requester, | |||||
StrCollection, | |||||
StreamKeyType, | |||||
UserID, | |||||
) | |||||
from synapse.util.caches.stream_change_cache import StreamChangeCache | from synapse.util.caches.stream_change_cache import StreamChangeCache | ||||
from synapse.util.metrics import Measure | from synapse.util.metrics import Measure | ||||
from synapse.util.retryutils import filter_destinations_by_retry_limiter | from synapse.util.retryutils import filter_destinations_by_retry_limiter | ||||
@@ -487,7 +494,7 @@ class TypingWriterHandler(FollowerTypingHandler): | |||||
raise Exception("Typing writer instance got typing info over replication") | raise Exception("Typing writer instance got typing info over replication") | ||||
class TypingNotificationEventSource(EventSource[int, JsonDict]): | |||||
class TypingNotificationEventSource(EventSource[int, JsonMapping]): | |||||
def __init__(self, hs: "HomeServer"): | def __init__(self, hs: "HomeServer"): | ||||
self._main_store = hs.get_datastores().main | self._main_store = hs.get_datastores().main | ||||
self.clock = hs.get_clock() | self.clock = hs.get_clock() | ||||
@@ -497,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): | |||||
# | # | ||||
self.get_typing_handler = hs.get_typing_handler | self.get_typing_handler = hs.get_typing_handler | ||||
def _make_event_for(self, room_id: str) -> JsonDict: | |||||
def _make_event_for(self, room_id: str) -> JsonMapping: | |||||
typing = self.get_typing_handler()._room_typing[room_id] | typing = self.get_typing_handler()._room_typing[room_id] | ||||
return { | return { | ||||
"type": EduTypes.TYPING, | "type": EduTypes.TYPING, | ||||
@@ -507,7 +514,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): | |||||
async def get_new_events_as( | async def get_new_events_as( | ||||
self, from_key: int, service: ApplicationService | self, from_key: int, service: ApplicationService | ||||
) -> Tuple[List[JsonDict], int]: | |||||
) -> Tuple[List[JsonMapping], int]: | |||||
"""Returns a set of new typing events that an appservice | """Returns a set of new typing events that an appservice | ||||
may be interested in. | may be interested in. | ||||
@@ -551,7 +558,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): | |||||
room_ids: Iterable[str], | room_ids: Iterable[str], | ||||
is_guest: bool, | is_guest: bool, | ||||
explicit_room_id: Optional[str] = None, | explicit_room_id: Optional[str] = None, | ||||
) -> Tuple[List[JsonDict], int]: | |||||
) -> Tuple[List[JsonMapping], int]: | |||||
with Measure(self.clock, "typing.get_new_events"): | with Measure(self.clock, "typing.get_new_events"): | ||||
from_key = int(from_key) | from_key = int(from_key) | ||||
handler = self.get_typing_handler() | handler = self.get_typing_handler() | ||||
@@ -131,7 +131,7 @@ class BulkPushRuleEvaluator: | |||||
async def _get_rules_for_event( | async def _get_rules_for_event( | ||||
self, | self, | ||||
event: EventBase, | event: EventBase, | ||||
) -> Dict[str, FilteredPushRules]: | |||||
) -> Mapping[str, FilteredPushRules]: | |||||
"""Get the push rules for all users who may need to be notified about | """Get the push rules for all users who may need to be notified about | ||||
the event. | the event. | ||||
@@ -45,7 +45,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore | from synapse.storage.databases.main.roommember import RoomMemberWorkerStore | ||||
from synapse.storage.types import Cursor | from synapse.storage.types import Cursor | ||||
from synapse.storage.util.sequence import build_sequence_generator | from synapse.storage.util.sequence import build_sequence_generator | ||||
from synapse.types import DeviceListUpdates, JsonDict | |||||
from synapse.types import DeviceListUpdates, JsonMapping | |||||
from synapse.util import json_encoder | from synapse.util import json_encoder | ||||
from synapse.util.caches.descriptors import _CacheContext, cached | from synapse.util.caches.descriptors import _CacheContext, cached | ||||
@@ -268,8 +268,8 @@ class ApplicationServiceTransactionWorkerStore( | |||||
self, | self, | ||||
service: ApplicationService, | service: ApplicationService, | ||||
events: Sequence[EventBase], | events: Sequence[EventBase], | ||||
ephemeral: List[JsonDict], | |||||
to_device_messages: List[JsonDict], | |||||
ephemeral: List[JsonMapping], | |||||
to_device_messages: List[JsonMapping], | |||||
one_time_keys_count: TransactionOneTimeKeysCount, | one_time_keys_count: TransactionOneTimeKeysCount, | ||||
unused_fallback_keys: TransactionUnusedFallbackKeys, | unused_fallback_keys: TransactionUnusedFallbackKeys, | ||||
device_list_summary: DeviceListUpdates, | device_list_summary: DeviceListUpdates, | ||||
@@ -55,7 +55,12 @@ from synapse.storage.util.id_generators import ( | |||||
AbstractStreamIdGenerator, | AbstractStreamIdGenerator, | ||||
StreamIdGenerator, | StreamIdGenerator, | ||||
) | ) | ||||
from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key | |||||
from synapse.types import ( | |||||
JsonDict, | |||||
JsonMapping, | |||||
StrCollection, | |||||
get_verify_key_from_cross_signing_key, | |||||
) | |||||
from synapse.util import json_decoder, json_encoder | from synapse.util import json_decoder, json_encoder | ||||
from synapse.util.caches.descriptors import cached, cachedList | from synapse.util.caches.descriptors import cached, cachedList | ||||
from synapse.util.caches.lrucache import LruCache | from synapse.util.caches.lrucache import LruCache | ||||
@@ -746,7 +751,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||||
@cancellable | @cancellable | ||||
async def get_user_devices_from_cache( | async def get_user_devices_from_cache( | ||||
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] | self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] | ||||
) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: | |||||
) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]: | |||||
"""Get the devices (and keys if any) for remote users from the cache. | """Get the devices (and keys if any) for remote users from the cache. | ||||
Args: | Args: | ||||
@@ -766,13 +771,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||||
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache | user_ids_not_in_cache = unique_user_ids - user_ids_in_cache | ||||
# First fetch all the users which all devices are to be returned. | # First fetch all the users which all devices are to be returned. | ||||
results: Dict[str, Mapping[str, JsonDict]] = {} | |||||
results: Dict[str, Mapping[str, JsonMapping]] = {} | |||||
for user_id in user_ids: | for user_id in user_ids: | ||||
if user_id in user_ids_in_cache: | if user_id in user_ids_in_cache: | ||||
results[user_id] = await self.get_cached_devices_for_user(user_id) | results[user_id] = await self.get_cached_devices_for_user(user_id) | ||||
# Then fetch all device-specific requests, but skip users we've already | # Then fetch all device-specific requests, but skip users we've already | ||||
# fetched all devices for. | # fetched all devices for. | ||||
device_specific_results: Dict[str, Dict[str, JsonDict]] = {} | |||||
device_specific_results: Dict[str, Dict[str, JsonMapping]] = {} | |||||
for user_id, device_id in user_and_device_ids: | for user_id, device_id in user_and_device_ids: | ||||
if user_id in user_ids_in_cache and user_id not in user_ids: | if user_id in user_ids_in_cache and user_id not in user_ids: | ||||
device = await self._get_cached_user_device(user_id, device_id) | device = await self._get_cached_user_device(user_id, device_id) | ||||
@@ -801,7 +806,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||||
return user_ids_in_cache | return user_ids_in_cache | ||||
@cached(num_args=2, tree=True) | @cached(num_args=2, tree=True) | ||||
async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: | |||||
async def _get_cached_user_device( | |||||
self, user_id: str, device_id: str | |||||
) -> JsonMapping: | |||||
content = await self.db_pool.simple_select_one_onecol( | content = await self.db_pool.simple_select_one_onecol( | ||||
table="device_lists_remote_cache", | table="device_lists_remote_cache", | ||||
keyvalues={"user_id": user_id, "device_id": device_id}, | keyvalues={"user_id": user_id, "device_id": device_id}, | ||||
@@ -811,7 +818,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||||
return db_to_json(content) | return db_to_json(content) | ||||
@cached() | @cached() | ||||
async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: | |||||
async def get_cached_devices_for_user( | |||||
self, user_id: str | |||||
) -> Mapping[str, JsonMapping]: | |||||
devices = await self.db_pool.simple_select_list( | devices = await self.db_pool.simple_select_list( | ||||
table="device_lists_remote_cache", | table="device_lists_remote_cache", | ||||
keyvalues={"user_id": user_id}, | keyvalues={"user_id": user_id}, | ||||
@@ -1042,7 +1051,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): | |||||
) | ) | ||||
async def get_device_list_last_stream_id_for_remotes( | async def get_device_list_last_stream_id_for_remotes( | ||||
self, user_ids: Iterable[str] | self, user_ids: Iterable[str] | ||||
) -> Dict[str, Optional[str]]: | |||||
) -> Mapping[str, Optional[str]]: | |||||
rows = await self.db_pool.simple_select_many_batch( | rows = await self.db_pool.simple_select_many_batch( | ||||
table="device_lists_remote_extremeties", | table="device_lists_remote_extremeties", | ||||
column="user_id", | column="user_id", | ||||
@@ -52,7 +52,7 @@ from synapse.storage.database import ( | |||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | ||||
from synapse.storage.engines import PostgresEngine | from synapse.storage.engines import PostgresEngine | ||||
from synapse.storage.util.id_generators import StreamIdGenerator | from synapse.storage.util.id_generators import StreamIdGenerator | ||||
from synapse.types import JsonDict | |||||
from synapse.types import JsonDict, JsonMapping | |||||
from synapse.util import json_decoder, json_encoder | from synapse.util import json_decoder, json_encoder | ||||
from synapse.util.caches.descriptors import cached, cachedList | from synapse.util.caches.descriptors import cached, cachedList | ||||
from synapse.util.cancellation import cancellable | from synapse.util.cancellation import cancellable | ||||
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
async def get_e2e_device_keys_for_federation_query( | async def get_e2e_device_keys_for_federation_query( | ||||
self, user_id: str | self, user_id: str | ||||
) -> Tuple[int, List[JsonDict]]: | |||||
) -> Tuple[int, Sequence[JsonMapping]]: | |||||
"""Get all devices (with any device keys) for a user | """Get all devices (with any device keys) for a user | ||||
Returns: | Returns: | ||||
@@ -174,7 +174,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
@cached(iterable=True) | @cached(iterable=True) | ||||
async def _get_e2e_device_keys_for_federation_query_inner( | async def _get_e2e_device_keys_for_federation_query_inner( | ||||
self, user_id: str | self, user_id: str | ||||
) -> List[JsonDict]: | |||||
) -> Sequence[JsonMapping]: | |||||
"""Get all devices (with any device keys) for a user""" | """Get all devices (with any device keys) for a user""" | ||||
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) | devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) | ||||
@@ -578,7 +578,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
@cached(max_entries=10000) | @cached(max_entries=10000) | ||||
async def count_e2e_one_time_keys( | async def count_e2e_one_time_keys( | ||||
self, user_id: str, device_id: str | self, user_id: str, device_id: str | ||||
) -> Dict[str, int]: | |||||
) -> Mapping[str, int]: | |||||
"""Count the number of one time keys the server has for a device | """Count the number of one time keys the server has for a device | ||||
Returns: | Returns: | ||||
A mapping from algorithm to number of keys for that algorithm. | A mapping from algorithm to number of keys for that algorithm. | ||||
@@ -812,7 +812,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
async def get_e2e_cross_signing_key( | async def get_e2e_cross_signing_key( | ||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None | self, user_id: str, key_type: str, from_user_id: Optional[str] = None | ||||
) -> Optional[JsonDict]: | |||||
) -> Optional[JsonMapping]: | |||||
"""Returns a user's cross-signing key. | """Returns a user's cross-signing key. | ||||
Args: | Args: | ||||
@@ -833,7 +833,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
return user_keys.get(key_type) | return user_keys.get(key_type) | ||||
@cached(num_args=1) | @cached(num_args=1) | ||||
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: | |||||
def _get_bare_e2e_cross_signing_keys( | |||||
self, user_id: str | |||||
) -> Mapping[str, JsonMapping]: | |||||
"""Dummy function. Only used to make a cache for | """Dummy function. Only used to make a cache for | ||||
_get_bare_e2e_cross_signing_keys_bulk. | _get_bare_e2e_cross_signing_keys_bulk. | ||||
""" | """ | ||||
@@ -846,7 +848,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
) | ) | ||||
async def _get_bare_e2e_cross_signing_keys_bulk( | async def _get_bare_e2e_cross_signing_keys_bulk( | ||||
self, user_ids: Iterable[str] | self, user_ids: Iterable[str] | ||||
) -> Dict[str, Optional[Mapping[str, JsonDict]]]: | |||||
) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: | |||||
"""Returns the cross-signing keys for a set of users. The output of this | """Returns the cross-signing keys for a set of users. The output of this | ||||
function should be passed to _get_e2e_cross_signing_signatures_txn if | function should be passed to _get_e2e_cross_signing_signatures_txn if | ||||
the signatures for the calling user need to be fetched. | the signatures for the calling user need to be fetched. | ||||
@@ -860,15 +862,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
their user ID will map to None. | their user ID will map to None. | ||||
""" | """ | ||||
result = await self.db_pool.runInteraction( | |||||
return await self.db_pool.runInteraction( | |||||
"get_bare_e2e_cross_signing_keys_bulk", | "get_bare_e2e_cross_signing_keys_bulk", | ||||
self._get_bare_e2e_cross_signing_keys_bulk_txn, | self._get_bare_e2e_cross_signing_keys_bulk_txn, | ||||
user_ids, | user_ids, | ||||
) | ) | ||||
# The `Optional` comes from the `@cachedList` decorator. | |||||
return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) | |||||
def _get_bare_e2e_cross_signing_keys_bulk_txn( | def _get_bare_e2e_cross_signing_keys_bulk_txn( | ||||
self, | self, | ||||
txn: LoggingTransaction, | txn: LoggingTransaction, | ||||
@@ -1026,7 +1025,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
@cancellable | @cancellable | ||||
async def get_e2e_cross_signing_keys_bulk( | async def get_e2e_cross_signing_keys_bulk( | ||||
self, user_ids: List[str], from_user_id: Optional[str] = None | self, user_ids: List[str], from_user_id: Optional[str] = None | ||||
) -> Dict[str, Optional[Mapping[str, JsonDict]]]: | |||||
) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: | |||||
"""Returns the cross-signing keys for a set of users. | """Returns the cross-signing keys for a set of users. | ||||
Args: | Args: | ||||
@@ -1043,7 +1042,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||||
if from_user_id: | if from_user_id: | ||||
result = cast( | result = cast( | ||||
Dict[str, Optional[Mapping[str, JsonDict]]], | |||||
Dict[str, Optional[Mapping[str, JsonMapping]]], | |||||
await self.db_pool.runInteraction( | await self.db_pool.runInteraction( | ||||
"get_e2e_cross_signing_signatures", | "get_e2e_cross_signing_signatures", | ||||
self._get_e2e_cross_signing_signatures_txn, | self._get_e2e_cross_signing_signatures_txn, | ||||
@@ -24,6 +24,7 @@ from typing import ( | |||||
Dict, | Dict, | ||||
Iterable, | Iterable, | ||||
List, | List, | ||||
Mapping, | |||||
MutableMapping, | MutableMapping, | ||||
Optional, | Optional, | ||||
Set, | Set, | ||||
@@ -1633,7 +1634,7 @@ class EventsWorkerStore(SQLBaseStore): | |||||
self, | self, | ||||
room_id: str, | room_id: str, | ||||
event_ids: Collection[str], | event_ids: Collection[str], | ||||
) -> Dict[str, bool]: | |||||
) -> Mapping[str, bool]: | |||||
"""Helper for have_seen_events | """Helper for have_seen_events | ||||
Returns: | Returns: | ||||
@@ -2325,7 +2326,7 @@ class EventsWorkerStore(SQLBaseStore): | |||||
@cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") | @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") | ||||
async def get_partial_state_events( | async def get_partial_state_events( | ||||
self, event_ids: Collection[str] | self, event_ids: Collection[str] | ||||
) -> Dict[str, bool]: | |||||
) -> Mapping[str, bool]: | |||||
"""Checks which of the given events have partial state | """Checks which of the given events have partial state | ||||
Args: | Args: | ||||
@@ -16,7 +16,7 @@ | |||||
import itertools | import itertools | ||||
import json | import json | ||||
import logging | import logging | ||||
from typing import Dict, Iterable, Optional, Tuple | |||||
from typing import Dict, Iterable, Mapping, Optional, Tuple | |||||
from canonicaljson import encode_canonical_json | from canonicaljson import encode_canonical_json | ||||
from signedjson.key import decode_verify_key_bytes | from signedjson.key import decode_verify_key_bytes | ||||
@@ -130,7 +130,7 @@ class KeyStore(CacheInvalidationWorkerStore): | |||||
) | ) | ||||
async def get_server_keys_json( | async def get_server_keys_json( | ||||
self, server_name_and_key_ids: Iterable[Tuple[str, str]] | self, server_name_and_key_ids: Iterable[Tuple[str, str]] | ||||
) -> Dict[Tuple[str, str], FetchKeyResult]: | |||||
) -> Mapping[Tuple[str, str], FetchKeyResult]: | |||||
""" | """ | ||||
Args: | Args: | ||||
server_name_and_key_ids: | server_name_and_key_ids: | ||||
@@ -200,7 +200,7 @@ class KeyStore(CacheInvalidationWorkerStore): | |||||
) | ) | ||||
async def get_server_keys_json_for_remote( | async def get_server_keys_json_for_remote( | ||||
self, server_name: str, key_ids: Iterable[str] | self, server_name: str, key_ids: Iterable[str] | ||||
) -> Dict[str, Optional[FetchKeyResultForRemote]]: | |||||
) -> Mapping[str, Optional[FetchKeyResultForRemote]]: | |||||
"""Fetch the cached keys for the given server/key IDs. | """Fetch the cached keys for the given server/key IDs. | ||||
If we have multiple entries for a given key ID, returns the most recent. | If we have multiple entries for a given key ID, returns the most recent. | ||||
@@ -11,7 +11,17 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast | |||||
from typing import ( | |||||
TYPE_CHECKING, | |||||
Any, | |||||
Dict, | |||||
Iterable, | |||||
List, | |||||
Mapping, | |||||
Optional, | |||||
Tuple, | |||||
cast, | |||||
) | |||||
from synapse.api.presence import PresenceState, UserPresenceState | from synapse.api.presence import PresenceState, UserPresenceState | ||||
from synapse.replication.tcp.streams import PresenceStream | from synapse.replication.tcp.streams import PresenceStream | ||||
@@ -249,7 +259,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) | |||||
) | ) | ||||
async def get_presence_for_users( | async def get_presence_for_users( | ||||
self, user_ids: Iterable[str] | self, user_ids: Iterable[str] | ||||
) -> Dict[str, UserPresenceState]: | |||||
) -> Mapping[str, UserPresenceState]: | |||||
rows = await self.db_pool.simple_select_many_batch( | rows = await self.db_pool.simple_select_many_batch( | ||||
table="presence_stream", | table="presence_stream", | ||||
column="user_id", | column="user_id", | ||||
@@ -216,7 +216,7 @@ class PushRulesWorkerStore( | |||||
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") | @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") | ||||
async def bulk_get_push_rules( | async def bulk_get_push_rules( | ||||
self, user_ids: Collection[str] | self, user_ids: Collection[str] | ||||
) -> Dict[str, FilteredPushRules]: | |||||
) -> Mapping[str, FilteredPushRules]: | |||||
if not user_ids: | if not user_ids: | ||||
return {} | return {} | ||||
@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import ( | |||||
MultiWriterIdGenerator, | MultiWriterIdGenerator, | ||||
StreamIdGenerator, | StreamIdGenerator, | ||||
) | ) | ||||
from synapse.types import JsonDict | |||||
from synapse.types import JsonDict, JsonMapping | |||||
from synapse.util import json_encoder | from synapse.util import json_encoder | ||||
from synapse.util.caches.descriptors import cached, cachedList | from synapse.util.caches.descriptors import cached, cachedList | ||||
from synapse.util.caches.stream_change_cache import StreamChangeCache | from synapse.util.caches.stream_change_cache import StreamChangeCache | ||||
@@ -218,7 +218,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
@cached() | @cached() | ||||
async def _get_receipts_for_user_with_orderings( | async def _get_receipts_for_user_with_orderings( | ||||
self, user_id: str, receipt_type: str | self, user_id: str, receipt_type: str | ||||
) -> JsonDict: | |||||
) -> JsonMapping: | |||||
""" | """ | ||||
Fetch receipts for all rooms that the given user is joined to. | Fetch receipts for all rooms that the given user is joined to. | ||||
@@ -258,7 +258,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
async def get_linearized_receipts_for_rooms( | async def get_linearized_receipts_for_rooms( | ||||
self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None | self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None | ||||
) -> List[dict]: | |||||
) -> List[JsonMapping]: | |||||
"""Get receipts for multiple rooms for sending to clients. | """Get receipts for multiple rooms for sending to clients. | ||||
Args: | Args: | ||||
@@ -287,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
async def get_linearized_receipts_for_room( | async def get_linearized_receipts_for_room( | ||||
self, room_id: str, to_key: int, from_key: Optional[int] = None | self, room_id: str, to_key: int, from_key: Optional[int] = None | ||||
) -> Sequence[JsonDict]: | |||||
) -> Sequence[JsonMapping]: | |||||
"""Get receipts for a single room for sending to clients. | """Get receipts for a single room for sending to clients. | ||||
Args: | Args: | ||||
@@ -310,7 +310,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
@cached(tree=True) | @cached(tree=True) | ||||
async def _get_linearized_receipts_for_room( | async def _get_linearized_receipts_for_room( | ||||
self, room_id: str, to_key: int, from_key: Optional[int] = None | self, room_id: str, to_key: int, from_key: Optional[int] = None | ||||
) -> Sequence[JsonDict]: | |||||
) -> Sequence[JsonMapping]: | |||||
"""See get_linearized_receipts_for_room""" | """See get_linearized_receipts_for_room""" | ||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: | def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: | ||||
@@ -353,7 +353,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
) | ) | ||||
async def _get_linearized_receipts_for_rooms( | async def _get_linearized_receipts_for_rooms( | ||||
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None | self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None | ||||
) -> Dict[str, Sequence[JsonDict]]: | |||||
) -> Mapping[str, Sequence[JsonMapping]]: | |||||
if not room_ids: | if not room_ids: | ||||
return {} | return {} | ||||
@@ -415,7 +415,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||||
) | ) | ||||
async def get_linearized_receipts_for_all_rooms( | async def get_linearized_receipts_for_all_rooms( | ||||
self, to_key: int, from_key: Optional[int] = None | self, to_key: int, from_key: Optional[int] = None | ||||
) -> Mapping[str, JsonDict]: | |||||
) -> Mapping[str, JsonMapping]: | |||||
"""Get receipts for all rooms between two stream_ids, up | """Get receipts for all rooms between two stream_ids, up | ||||
to a limit of the latest 100 read receipts. | to a limit of the latest 100 read receipts. | ||||
@@ -519,7 +519,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||||
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") | @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") | ||||
async def get_applicable_edits( | async def get_applicable_edits( | ||||
self, event_ids: Collection[str] | self, event_ids: Collection[str] | ||||
) -> Dict[str, Optional[EventBase]]: | |||||
) -> Mapping[str, Optional[EventBase]]: | |||||
"""Get the most recent edit (if any) that has happened for the given | """Get the most recent edit (if any) that has happened for the given | ||||
events. | events. | ||||
@@ -605,7 +605,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||||
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") | @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") | ||||
async def get_thread_summaries( | async def get_thread_summaries( | ||||
self, event_ids: Collection[str] | self, event_ids: Collection[str] | ||||
) -> Dict[str, Optional[Tuple[int, EventBase]]]: | |||||
) -> Mapping[str, Optional[Tuple[int, EventBase]]]: | |||||
"""Get the number of threaded replies and the latest reply (if any) for the given events. | """Get the number of threaded replies and the latest reply (if any) for the given events. | ||||
Args: | Args: | ||||
@@ -779,7 +779,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||||
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids") | @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") | ||||
async def get_threads_participated( | async def get_threads_participated( | ||||
self, event_ids: Collection[str], user_id: str | self, event_ids: Collection[str], user_id: str | ||||
) -> Dict[str, bool]: | |||||
) -> Mapping[str, bool]: | |||||
"""Get whether the requesting user participated in the given threads. | """Get whether the requesting user participated in the given threads. | ||||
This is separate from get_thread_summaries since that can be cached across | This is separate from get_thread_summaries since that can be cached across | ||||
@@ -191,7 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||||
) | ) | ||||
async def get_subset_users_in_room_with_profiles( | async def get_subset_users_in_room_with_profiles( | ||||
self, room_id: str, user_ids: Collection[str] | self, room_id: str, user_ids: Collection[str] | ||||
) -> Dict[str, ProfileInfo]: | |||||
) -> Mapping[str, ProfileInfo]: | |||||
"""Get a mapping from user ID to profile information for a list of users | """Get a mapping from user ID to profile information for a list of users | ||||
in a given room. | in a given room. | ||||
@@ -676,7 +676,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||||
) | ) | ||||
async def _get_rooms_for_users( | async def _get_rooms_for_users( | ||||
self, user_ids: Collection[str] | self, user_ids: Collection[str] | ||||
) -> Dict[str, FrozenSet[str]]: | |||||
) -> Mapping[str, FrozenSet[str]]: | |||||
"""A batched version of `get_rooms_for_user`. | """A batched version of `get_rooms_for_user`. | ||||
Returns: | Returns: | ||||
@@ -881,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||||
) | ) | ||||
async def _get_user_ids_from_membership_event_ids( | async def _get_user_ids_from_membership_event_ids( | ||||
self, event_ids: Iterable[str] | self, event_ids: Iterable[str] | ||||
) -> Dict[str, Optional[str]]: | |||||
) -> Mapping[str, Optional[str]]: | |||||
"""For given set of member event_ids check if they point to a join | """For given set of member event_ids check if they point to a join | ||||
event. | event. | ||||
@@ -1191,7 +1191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||||
) | ) | ||||
async def get_membership_from_event_ids( | async def get_membership_from_event_ids( | ||||
self, member_event_ids: Iterable[str] | self, member_event_ids: Iterable[str] | ||||
) -> Dict[str, Optional[EventIdMembership]]: | |||||
) -> Mapping[str, Optional[EventIdMembership]]: | |||||
"""Get user_id and membership of a set of event IDs. | """Get user_id and membership of a set of event IDs. | ||||
Returns: | Returns: | ||||
@@ -14,7 +14,17 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
import collections.abc | import collections.abc | ||||
import logging | import logging | ||||
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple | |||||
from typing import ( | |||||
TYPE_CHECKING, | |||||
Any, | |||||
Collection, | |||||
Dict, | |||||
Iterable, | |||||
Mapping, | |||||
Optional, | |||||
Set, | |||||
Tuple, | |||||
) | |||||
import attr | import attr | ||||
@@ -372,7 +382,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||||
) | ) | ||||
async def _get_state_group_for_events( | async def _get_state_group_for_events( | ||||
self, event_ids: Collection[str] | self, event_ids: Collection[str] | ||||
) -> Dict[str, int]: | |||||
) -> Mapping[str, int]: | |||||
"""Returns mapping event_id -> state_group. | """Returns mapping event_id -> state_group. | ||||
Raises: | Raises: | ||||
@@ -14,7 +14,7 @@ | |||||
import logging | import logging | ||||
from enum import Enum | from enum import Enum | ||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast | |||||
from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast | |||||
import attr | import attr | ||||
from canonicaljson import encode_canonical_json | from canonicaljson import encode_canonical_json | ||||
@@ -210,7 +210,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): | |||||
) | ) | ||||
async def get_destination_retry_timings_batch( | async def get_destination_retry_timings_batch( | ||||
self, destinations: StrCollection | self, destinations: StrCollection | ||||
) -> Dict[str, Optional[DestinationRetryTimings]]: | |||||
) -> Mapping[str, Optional[DestinationRetryTimings]]: | |||||
rows = await self.db_pool.simple_select_many_batch( | rows = await self.db_pool.simple_select_many_batch( | ||||
table="destinations", | table="destinations", | ||||
iterable=destinations, | iterable=destinations, | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Dict, Iterable | |||||
from typing import Iterable, Mapping | |||||
from synapse.storage.database import LoggingTransaction | from synapse.storage.database import LoggingTransaction | ||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore | from synapse.storage.databases.main import CacheInvalidationWorkerStore | ||||
@@ -40,7 +40,7 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore): | |||||
return bool(result) | return bool(result) | ||||
@cachedList(cached_method_name="is_user_erased", list_name="user_ids") | @cachedList(cached_method_name="is_user_erased", list_name="user_ids") | ||||
async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]: | |||||
async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]: | |||||
""" | """ | ||||
Checks which users in a list have requested erasure | Checks which users in a list have requested erasure | ||||