@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -37,7 +37,7 @@ from synapse.api.constants import EduTypes, EventContentFields | |||
from synapse.api.errors import SynapseError | |||
from synapse.api.presence import UserPresenceState | |||
from synapse.events import EventBase, relation_from_event | |||
from synapse.types import JsonDict, RoomID, UserID | |||
from synapse.types import JsonDict, JsonMapping, RoomID, UserID | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -191,7 +191,7 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict) | |||
class FilterCollection: | |||
def __init__(self, hs: "HomeServer", filter_json: JsonDict): | |||
def __init__(self, hs: "HomeServer", filter_json: JsonMapping): | |||
self._filter_json = filter_json | |||
room_filter_json = self._filter_json.get("room", {}) | |||
@@ -219,7 +219,7 @@ class FilterCollection: | |||
def __repr__(self) -> str: | |||
return "<FilterCollection %s>" % (json.dumps(self._filter_json),) | |||
def get_filter_json(self) -> JsonDict: | |||
def get_filter_json(self) -> JsonMapping: | |||
return self._filter_json | |||
def timeline_limit(self) -> int: | |||
@@ -313,7 +313,7 @@ class FilterCollection: | |||
class Filter: | |||
def __init__(self, hs: "HomeServer", filter_json: JsonDict): | |||
def __init__(self, hs: "HomeServer", filter_json: JsonMapping): | |||
self._hs = hs | |||
self._store = hs.get_datastores().main | |||
self.filter_json = filter_json | |||
@@ -64,7 +64,7 @@ from synapse.federation.transport.client import SendJoinResponse | |||
from synapse.http.client import is_unknown_endpoint | |||
from synapse.http.types import QueryParams | |||
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace | |||
from synapse.types import JsonDict, UserID, get_domain_from_id | |||
from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id | |||
from synapse.util.async_helpers import concurrently_execute | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
from synapse.util.retryutils import NotRetryingDestination | |||
@@ -1704,7 +1704,7 @@ class FederationClient(FederationBase): | |||
async def timestamp_to_event( | |||
self, | |||
*, | |||
destinations: List[str], | |||
destinations: StrCollection, | |||
room_id: str, | |||
timestamp: int, | |||
direction: Direction, | |||
@@ -1538,7 +1538,7 @@ class FederationEventHandler: | |||
logger.exception("Failed to resync device for %s", sender) | |||
async def backfill_event_id( | |||
self, destinations: List[str], room_id: str, event_id: str | |||
self, destinations: StrCollection, room_id: str, event_id: str | |||
) -> PulledPduInfo: | |||
"""Backfill a single event and persist it as a non-outlier which means | |||
we also pull in all of the state and auth events necessary for it. | |||
@@ -13,7 +13,17 @@ | |||
# limitations under the License. | |||
import enum | |||
import logging | |||
from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Collection, | |||
Dict, | |||
FrozenSet, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
) | |||
import attr | |||
@@ -245,7 +255,7 @@ class RelationsHandler: | |||
async def get_references_for_events( | |||
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() | |||
) -> Dict[str, List[_RelatedEvent]]: | |||
) -> Mapping[str, Sequence[_RelatedEvent]]: | |||
"""Get a list of references to the given events. | |||
Args: | |||
@@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseErro | |||
from synapse.http.server import HttpServer | |||
from synapse.http.servlet import RestServlet, parse_json_object_from_request | |||
from synapse.http.site import SynapseRequest | |||
from synapse.types import JsonDict, UserID | |||
from synapse.types import JsonDict, JsonMapping, UserID | |||
from ._base import client_patterns, set_timeline_upper_limit | |||
@@ -41,7 +41,7 @@ class GetFilterRestServlet(RestServlet): | |||
async def on_GET( | |||
self, request: SynapseRequest, user_id: str, filter_id: str | |||
) -> Tuple[int, JsonDict]: | |||
) -> Tuple[int, JsonMapping]: | |||
target_user = UserID.from_string(user_id) | |||
requester = await self.auth.get_user_by_req(request) | |||
@@ -582,7 +582,7 @@ class StateStorageController: | |||
@trace | |||
@tag_args | |||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: | |||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: | |||
"""Get current hosts in room based on current state. | |||
Blocks until we have full state for the given room. This only happens for rooms | |||
@@ -25,7 +25,7 @@ from synapse.storage.database import ( | |||
LoggingTransaction, | |||
) | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.types import JsonDict, UserID | |||
from synapse.types import JsonDict, JsonMapping, UserID | |||
from synapse.util.caches.descriptors import cached | |||
if TYPE_CHECKING: | |||
@@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore): | |||
@cached(num_args=2) | |||
async def get_user_filter( | |||
self, user_id: UserID, filter_id: Union[int, str] | |||
) -> JsonDict: | |||
) -> JsonMapping: | |||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail | |||
# with a coherent error message rather than 500 M_UNKNOWN. | |||
try: | |||
@@ -465,7 +465,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids") | |||
async def get_references_for_events( | |||
self, event_ids: Collection[str] | |||
) -> Mapping[str, Optional[List[_RelatedEvent]]]: | |||
) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]: | |||
"""Get a list of references to the given events. | |||
Args: | |||
@@ -931,7 +931,7 @@ class RelationsWorkerStore(SQLBaseStore): | |||
room_id: str, | |||
limit: int = 5, | |||
from_token: Optional[ThreadsNextBatch] = None, | |||
) -> Tuple[List[str], Optional[ThreadsNextBatch]]: | |||
) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]: | |||
"""Get a list of thread IDs, ordered by topological ordering of their | |||
latest reply. | |||
@@ -984,7 +984,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
) | |||
@cached(iterable=True, max_entries=10000) | |||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: | |||
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: | |||
""" | |||
Get current hosts in room based on current state. | |||
@@ -1013,12 +1013,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
# `get_users_in_room` rather than funky SQL. | |||
domains = await self.get_current_hosts_in_room(room_id) | |||
return list(domains) | |||
return tuple(domains) | |||
# For PostgreSQL we can use a regex to pull out the domains from the | |||
# joined users in `current_state_events` via regex. | |||
def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]: | |||
def get_current_hosts_in_room_ordered_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[str, ...]: | |||
# Returns a list of servers currently joined in the room sorted by | |||
# longest in the room first (aka. with the lowest depth). The | |||
# heuristic of sorting by servers who have been in the room the | |||
@@ -1043,7 +1045,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): | |||
""" | |||
txn.execute(sql, (room_id,)) | |||
# `server_domain` will be `NULL` for malformed MXIDs with no colons. | |||
return [d for d, in txn if d is not None] | |||
return tuple(d for d, in txn if d is not None) | |||
return await self.db_pool.runInteraction( | |||
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn | |||
@@ -15,10 +15,10 @@ | |||
import logging | |||
from typing import ( | |||
Any, | |||
Dict, | |||
Generator, | |||
Iterable, | |||
List, | |||
Mapping, | |||
NoReturn, | |||
Optional, | |||
Set, | |||
@@ -96,7 +96,7 @@ class DescriptorTestCase(unittest.TestCase): | |||
self.mock = mock.Mock() | |||
@descriptors.cached(num_args=1) | |||
def fn(self, arg1: int, arg2: int) -> mock.Mock: | |||
def fn(self, arg1: int, arg2: int) -> str: | |||
return self.mock(arg1, arg2) | |||
obj = Cls() | |||
@@ -228,8 +228,9 @@ class DescriptorTestCase(unittest.TestCase): | |||
call_count = 0 | |||
@cached() | |||
def fn(self, arg1: int) -> Optional[Deferred]: | |||
def fn(self, arg1: int) -> Deferred: | |||
self.call_count += 1 | |||
assert self.result is not None | |||
return self.result | |||
obj = Cls() | |||
@@ -401,21 +402,21 @@ class DescriptorTestCase(unittest.TestCase): | |||
self.mock = mock.Mock() | |||
@descriptors.cached(iterable=True) | |||
def fn(self, arg1: int, arg2: int) -> List[str]: | |||
def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]: | |||
return self.mock(arg1, arg2) | |||
obj = Cls() | |||
obj.mock.return_value = ["spam", "eggs"] | |||
obj.mock.return_value = ("spam", "eggs") | |||
r = obj.fn(1, 2) | |||
self.assertEqual(r.result, ["spam", "eggs"]) | |||
self.assertEqual(r.result, ("spam", "eggs")) | |||
obj.mock.assert_called_once_with(1, 2) | |||
obj.mock.reset_mock() | |||
# a call with different params should call the mock again | |||
obj.mock.return_value = ["chips"] | |||
obj.mock.return_value = ("chips",) | |||
r = obj.fn(1, 3) | |||
self.assertEqual(r.result, ["chips"]) | |||
self.assertEqual(r.result, ("chips",)) | |||
obj.mock.assert_called_once_with(1, 3) | |||
obj.mock.reset_mock() | |||
@@ -423,9 +424,9 @@ class DescriptorTestCase(unittest.TestCase): | |||
self.assertEqual(len(obj.fn.cache.cache), 3) | |||
r = obj.fn(1, 2) | |||
self.assertEqual(r.result, ["spam", "eggs"]) | |||
self.assertEqual(r.result, ("spam", "eggs")) | |||
r = obj.fn(1, 3) | |||
self.assertEqual(r.result, ["chips"]) | |||
self.assertEqual(r.result, ("chips",)) | |||
obj.mock.assert_not_called() | |||
def test_cache_iterable_with_sync_exception(self) -> None: | |||
@@ -784,7 +785,9 @@ class CachedListDescriptorTestCase(unittest.TestCase): | |||
pass | |||
@descriptors.cachedList(cached_method_name="fn", list_name="args1") | |||
async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]: | |||
async def list_fn( | |||
self, args1: Iterable[int], arg2: int | |||
) -> Mapping[int, str]: | |||
context = current_context() | |||
assert isinstance(context, LoggingContext) | |||
assert context.name == "c1" | |||
@@ -847,11 +850,11 @@ class CachedListDescriptorTestCase(unittest.TestCase): | |||
pass | |||
@descriptors.cachedList(cached_method_name="fn", list_name="args1") | |||
def list_fn(self, args1: List[int]) -> "Deferred[dict]": | |||
def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]": | |||
return self.mock(args1) | |||
obj = Cls() | |||
deferred_result: "Deferred[dict]" = Deferred() | |||
deferred_result: "Deferred[Mapping[int, str]]" = Deferred() | |||
obj.mock.return_value = deferred_result | |||
# start off several concurrent lookups of the same key | |||
@@ -890,7 +893,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): | |||
pass | |||
@descriptors.cachedList(cached_method_name="fn", list_name="args1") | |||
async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]: | |||
async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]: | |||
# we want this to behave like an asynchronous function | |||
await run_on_reactor() | |||
return self.mock(args1, arg2) | |||
@@ -929,7 +932,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): | |||
pass | |||
@cachedList(cached_method_name="fn", list_name="args") | |||
async def list_fn(self, args: List[int]) -> Dict[int, str]: | |||
async def list_fn(self, args: List[int]) -> Mapping[int, str]: | |||
await complete_lookup | |||
return {arg: str(arg) for arg in args} | |||
@@ -964,7 +967,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): | |||
pass | |||
@cachedList(cached_method_name="fn", list_name="args") | |||
async def list_fn(self, args: List[int]) -> Dict[int, str]: | |||
async def list_fn(self, args: List[int]) -> Mapping[int, str]: | |||
await make_deferred_yieldable(complete_lookup) | |||
self.inner_context_was_finished = current_context().finished | |||
return {arg: str(arg) for arg in args} | |||