@@ -0,0 +1 @@ | |||
Improve type hints. |
@@ -27,9 +27,7 @@ from typing import ( | |||
Any, | |||
Awaitable, | |||
Callable, | |||
Collection, | |||
Dict, | |||
Iterable, | |||
List, | |||
NoReturn, | |||
Optional, | |||
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_ | |||
from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( | |||
load_legacy_third_party_event_rules, | |||
) | |||
from synapse.types import ISynapseReactor | |||
from synapse.types import ISynapseReactor, StrCollection | |||
from synapse.util import SYNAPSE_VERSION | |||
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries | |||
from synapse.util.daemonize import daemonize_process | |||
@@ -278,7 +276,7 @@ def register_start( | |||
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) | |||
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: | |||
def listen_metrics(bind_addresses: StrCollection, port: int) -> None: | |||
""" | |||
Start Prometheus metrics server. | |||
""" | |||
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: | |||
def listen_manhole( | |||
bind_addresses: Collection[str], | |||
bind_addresses: StrCollection, | |||
port: int, | |||
manhole_settings: ManholeConfig, | |||
manhole_globals: dict, | |||
@@ -339,7 +337,7 @@ def listen_manhole( | |||
def listen_tcp( | |||
bind_addresses: Collection[str], | |||
bind_addresses: StrCollection, | |||
port: int, | |||
factory: ServerFactory, | |||
reactor: IReactorTCP = reactor, | |||
@@ -448,7 +446,7 @@ def listen_http( | |||
def listen_ssl( | |||
bind_addresses: Collection[str], | |||
bind_addresses: StrCollection, | |||
port: int, | |||
factory: ServerFactory, | |||
context_factory: IOpenSSLContextFactory, | |||
@@ -26,7 +26,6 @@ from textwrap import dedent | |||
from typing import ( | |||
Any, | |||
ClassVar, | |||
Collection, | |||
Dict, | |||
Iterable, | |||
Iterator, | |||
@@ -384,7 +383,7 @@ class RootConfig: | |||
config_classes: List[Type[Config]] = [] | |||
def __init__(self, config_files: Collection[str] = ()): | |||
def __init__(self, config_files: StrSequence = ()): | |||
# Capture absolute paths here, so we can reload config after we daemonize. | |||
self.config_files = [os.path.abspath(path) for path in config_files] | |||
@@ -25,7 +25,6 @@ from typing import ( | |||
Iterable, | |||
List, | |||
Optional, | |||
Sequence, | |||
Tuple, | |||
Type, | |||
TypeVar, | |||
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta): | |||
def keys(self) -> Iterable[str]: | |||
return self._dict.keys() | |||
def prev_event_ids(self) -> Sequence[str]: | |||
def prev_event_ids(self) -> List[str]: | |||
"""Returns the list of prev event IDs. The order matches the order | |||
specified in the event, though there is no meaning to it. | |||
@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase): | |||
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) | |||
return self._event_id | |||
def prev_event_ids(self) -> Sequence[str]: | |||
def prev_event_ids(self) -> List[str]: | |||
"""Returns the list of prev event IDs. The order matches the order | |||
specified in the event, though there is no meaning to it. | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | |||
import attr | |||
from signedjson.types import SigningKey | |||
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event | |||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict | |||
from synapse.state import StateHandler | |||
from synapse.storage.databases.main import DataStore | |||
from synapse.types import EventID, JsonDict | |||
from synapse.types import EventID, JsonDict, StrCollection | |||
from synapse.types.state import StateFilter | |||
from synapse.util import Clock | |||
from synapse.util.stringutils import random_string | |||
@@ -103,7 +103,7 @@ class EventBuilder: | |||
async def build( | |||
self, | |||
prev_event_ids: Collection[str], | |||
prev_event_ids: StrCollection, | |||
auth_event_ids: Optional[List[str]], | |||
depth: Optional[int] = None, | |||
) -> EventBase: | |||
@@ -136,7 +136,7 @@ class EventBuilder: | |||
format_version = self.room_version.event_format | |||
# The types of auth/prev events changes between event versions. | |||
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] | |||
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]] | |||
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] | |||
if format_version == EventFormatVersions.ROOM_V1_V2: | |||
auth_events = await self._store.add_event_hashes(auth_event_ids) | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import collections.abc | |||
from typing import Iterable, List, Type, Union, cast | |||
from typing import List, Type, Union, cast | |||
import jsonschema | |||
from pydantic import Field, StrictBool, StrictStr | |||
@@ -36,7 +36,7 @@ from synapse.events.utils import ( | |||
from synapse.federation.federation_server import server_matches_acl_event | |||
from synapse.http.servlet import validate_json_object | |||
from synapse.rest.models import RequestBodyModel | |||
from synapse.types import EventID, JsonDict, RoomID, UserID | |||
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID | |||
class EventValidator: | |||
@@ -225,7 +225,7 @@ class EventValidator: | |||
self._ensure_state_event(event) | |||
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None: | |||
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None: | |||
for s in keys: | |||
if s not in d: | |||
raise SynapseError(400, "'%s' not in content" % (s,)) | |||
@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent | |||
from synapse.http.types import QueryParams | |||
from synapse.logging.context import make_deferred_yieldable, run_in_background | |||
from synapse.logging.opentracing import set_tag, start_active_span, tags | |||
from synapse.types import ISynapseReactor | |||
from synapse.types import ISynapseReactor, StrSequence | |||
from synapse.util import json_decoder | |||
from synapse.util.async_helpers import timeout_deferred | |||
@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu | |||
# the value actually has to be a List, but List is invariant so we can't specify that | |||
# the entries can either be Lists or bytes. | |||
RawHeaderValue = Union[ | |||
List[str], | |||
StrSequence, | |||
List[bytes], | |||
List[Union[str, bytes]], | |||
Tuple[str, ...], | |||
Tuple[bytes, ...], | |||
Tuple[Union[str, bytes], ...], | |||
] | |||
@@ -18,7 +18,6 @@ import logging | |||
from http import HTTPStatus | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
@@ -38,7 +37,7 @@ from twisted.web.server import Request | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.http import redact_uri | |||
from synapse.http.server import HttpServer | |||
from synapse.types import JsonDict, RoomAlias, RoomID | |||
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection | |||
from synapse.util import json_decoder | |||
if TYPE_CHECKING: | |||
@@ -340,7 +339,7 @@ def parse_string( | |||
name: str, | |||
default: str, | |||
*, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> str: | |||
... | |||
@@ -352,7 +351,7 @@ def parse_string( | |||
name: str, | |||
*, | |||
required: Literal[True], | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> str: | |||
... | |||
@@ -365,7 +364,7 @@ def parse_string( | |||
*, | |||
default: Optional[str] = None, | |||
required: bool = False, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[str]: | |||
... | |||
@@ -376,7 +375,7 @@ def parse_string( | |||
name: str, | |||
default: Optional[str] = None, | |||
required: bool = False, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[str]: | |||
""" | |||
@@ -485,7 +484,7 @@ def parse_enum( | |||
def _parse_string_value( | |||
value: bytes, | |||
allowed_values: Optional[Iterable[str]], | |||
allowed_values: Optional[StrCollection], | |||
name: str, | |||
encoding: str, | |||
) -> str: | |||
@@ -511,7 +510,7 @@ def parse_strings_from_args( | |||
args: Mapping[bytes, Sequence[bytes]], | |||
name: str, | |||
*, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[List[str]]: | |||
... | |||
@@ -523,7 +522,7 @@ def parse_strings_from_args( | |||
name: str, | |||
default: List[str], | |||
*, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> List[str]: | |||
... | |||
@@ -535,7 +534,7 @@ def parse_strings_from_args( | |||
name: str, | |||
*, | |||
required: Literal[True], | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> List[str]: | |||
... | |||
@@ -548,7 +547,7 @@ def parse_strings_from_args( | |||
default: Optional[List[str]] = None, | |||
*, | |||
required: bool = False, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[List[str]]: | |||
... | |||
@@ -559,7 +558,7 @@ def parse_strings_from_args( | |||
name: str, | |||
default: Optional[List[str]] = None, | |||
required: bool = False, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[List[str]]: | |||
""" | |||
@@ -610,7 +609,7 @@ def parse_string_from_args( | |||
name: str, | |||
default: Optional[str] = None, | |||
*, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[str]: | |||
... | |||
@@ -623,7 +622,7 @@ def parse_string_from_args( | |||
default: Optional[str] = None, | |||
*, | |||
required: Literal[True], | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> str: | |||
... | |||
@@ -635,7 +634,7 @@ def parse_string_from_args( | |||
name: str, | |||
default: Optional[str] = None, | |||
required: bool = False, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[str]: | |||
... | |||
@@ -646,7 +645,7 @@ def parse_string_from_args( | |||
name: str, | |||
default: Optional[str] = None, | |||
required: bool = False, | |||
allowed_values: Optional[Iterable[str]] = None, | |||
allowed_values: Optional[StrCollection] = None, | |||
encoding: str = "ascii", | |||
) -> Optional[str]: | |||
""" | |||
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request( | |||
return validate_json_object(content, model_type) | |||
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: | |||
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None: | |||
absent = [] | |||
for k in required: | |||
if k not in body: | |||
@@ -25,7 +25,6 @@ from typing import ( | |||
Iterable, | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Set, | |||
Tuple, | |||
Type, | |||
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401 | |||
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager | |||
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest | |||
from synapse.metrics._types import Collector | |||
from synapse.types import StrSequence | |||
from synapse.util import SYNAPSE_VERSION | |||
logger = logging.getLogger(__name__) | |||
@@ -81,7 +81,7 @@ class LaterGauge(Collector): | |||
name: str | |||
desc: str | |||
labels: Optional[Sequence[str]] = attr.ib(hash=False) | |||
labels: Optional[StrSequence] = attr.ib(hash=False) | |||
# callback: should either return a value (if there are no labels for this metric), | |||
# or dict mapping from a label tuple to a value | |||
caller: Callable[ | |||
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector): | |||
self, | |||
name: str, | |||
desc: str, | |||
labels: Sequence[str], | |||
sub_metrics: Sequence[str], | |||
labels: StrSequence, | |||
sub_metrics: StrSequence, | |||
): | |||
self.name = name | |||
self.desc = desc | |||
@@ -104,7 +104,7 @@ class _NotifierUserStream: | |||
def __init__( | |||
self, | |||
user_id: str, | |||
rooms: Collection[str], | |||
rooms: StrCollection, | |||
current_token: StreamToken, | |||
time_now_ms: int, | |||
): | |||
@@ -457,7 +457,7 @@ class Notifier: | |||
stream_key: str, | |||
new_token: Union[int, RoomStreamToken], | |||
users: Optional[Collection[Union[str, UserID]]] = None, | |||
rooms: Optional[Collection[str]] = None, | |||
rooms: Optional[StrCollection] = None, | |||
) -> None: | |||
"""Used to inform listeners that something has happened event wise. | |||
@@ -529,7 +529,7 @@ class Notifier: | |||
user_id: str, | |||
timeout: int, | |||
callback: Callable[[StreamToken, StreamToken], Awaitable[T]], | |||
room_ids: Optional[Collection[str]] = None, | |||
room_ids: Optional[StrCollection] = None, | |||
from_token: StreamToken = StreamToken.START, | |||
) -> T: | |||
"""Wait until the callback returns a non empty response or the | |||
@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, | |||
from synapse.api.errors import InteractiveAuthIncompleteError | |||
from synapse.api.urls import CLIENT_API_PREFIX | |||
from synapse.types import JsonDict | |||
from synapse.types import JsonDict, StrCollection | |||
logger = logging.getLogger(__name__) | |||
def client_patterns( | |||
path_regex: str, | |||
releases: Iterable[str] = ("r0", "v3"), | |||
releases: StrCollection = ("r0", "v3"), | |||
unstable: bool = True, | |||
v1: bool = False, | |||
) -> Iterable[Pattern]: | |||
@@ -20,7 +20,6 @@ from typing import ( | |||
Any, | |||
Awaitable, | |||
Callable, | |||
Collection, | |||
DefaultDict, | |||
Dict, | |||
FrozenSet, | |||
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace | |||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet | |||
from synapse.state import v1, v2 | |||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour | |||
from synapse.types import StateMap | |||
from synapse.types import StateMap, StrCollection | |||
from synapse.types.state import StateFilter | |||
from synapse.util.async_helpers import Linearizer | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
@@ -197,7 +196,7 @@ class StateHandler: | |||
async def compute_state_after_events( | |||
self, | |||
room_id: str, | |||
event_ids: Collection[str], | |||
event_ids: StrCollection, | |||
state_filter: Optional[StateFilter] = None, | |||
await_full_state: bool = True, | |||
) -> StateMap[str]: | |||
@@ -231,7 +230,7 @@ class StateHandler: | |||
return await ret.get_state(self._state_storage_controller, state_filter) | |||
async def get_current_user_ids_in_room( | |||
self, room_id: str, latest_event_ids: Collection[str] | |||
self, room_id: str, latest_event_ids: StrCollection | |||
) -> Set[str]: | |||
""" | |||
Get the users IDs who are currently in a room. | |||
@@ -256,7 +255,7 @@ class StateHandler: | |||
return await self.store.get_joined_user_ids_from_state(room_id, state) | |||
async def get_hosts_in_room_at_events( | |||
self, room_id: str, event_ids: Collection[str] | |||
self, room_id: str, event_ids: StrCollection | |||
) -> FrozenSet[str]: | |||
"""Get the hosts that were in a room at the given event ids | |||
@@ -470,7 +469,7 @@ class StateHandler: | |||
@trace | |||
@measure_func() | |||
async def resolve_state_groups_for_events( | |||
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True | |||
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True | |||
) -> _StateCacheEntry: | |||
"""Given a list of event_ids this method fetches the state at each | |||
event, resolves conflicts between them and returns them. | |||
@@ -882,7 +881,7 @@ class StateResolutionStore: | |||
store: "DataStore" | |||
def get_events( | |||
self, event_ids: Collection[str], allow_rejected: bool = False | |||
self, event_ids: StrCollection, allow_rejected: bool = False | |||
) -> Awaitable[Dict[str, EventBase]]: | |||
"""Get events from the database | |||
@@ -17,7 +17,6 @@ import logging | |||
from typing import ( | |||
Awaitable, | |||
Callable, | |||
Collection, | |||
Dict, | |||
Iterable, | |||
List, | |||
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes | |||
from synapse.api.errors import AuthError | |||
from synapse.api.room_versions import RoomVersion | |||
from synapse.events import EventBase | |||
from synapse.types import MutableStateMap, StateMap | |||
from synapse.types import MutableStateMap, StateMap, StrCollection | |||
logger = logging.getLogger(__name__) | |||
@@ -45,7 +44,7 @@ async def resolve_events_with_store( | |||
room_version: RoomVersion, | |||
state_sets: Sequence[StateMap[str]], | |||
event_map: Optional[Dict[str, EventBase]], | |||
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], | |||
state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]], | |||
) -> StateMap[str]: | |||
""" | |||
Args: | |||
@@ -19,7 +19,6 @@ from typing import ( | |||
Any, | |||
Awaitable, | |||
Callable, | |||
Collection, | |||
Dict, | |||
Generator, | |||
Iterable, | |||
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes | |||
from synapse.api.errors import AuthError | |||
from synapse.api.room_versions import RoomVersion | |||
from synapse.events import EventBase | |||
from synapse.types import MutableStateMap, StateMap | |||
from synapse.types import MutableStateMap, StateMap, StrCollection | |||
logger = logging.getLogger(__name__) | |||
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol): | |||
# This is usually synapse.state.StateResolutionStore, but it's replaced with a | |||
# TestStateResolutionStore in tests. | |||
def get_events( | |||
self, event_ids: Collection[str], allow_rejected: bool = False | |||
self, event_ids: StrCollection, allow_rejected: bool = False | |||
) -> Awaitable[Dict[str, EventBase]]: | |||
... | |||
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference( | |||
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) | |||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) | |||
auth_difference_unpersisted_part: Collection[str] = union - intersection | |||
auth_difference_unpersisted_part: StrCollection = union - intersection | |||
else: | |||
auth_difference_unpersisted_part = () | |||
state_sets_ids = [set(state_set.values()) for state_set in state_sets] | |||
@@ -47,7 +47,7 @@ from synapse.storage.database import ( | |||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | |||
from synapse.storage.databases.main.signatures import SignatureWorkerStore | |||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine | |||
from synapse.types import JsonDict, StrCollection | |||
from synapse.types import JsonDict, StrCollection, StrSequence | |||
from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached | |||
from synapse.util.caches.lrucache import LruCache | |||
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||
) | |||
@cached(max_entries=5000, iterable=True) | |||
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: | |||
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence: | |||
return await self.db_pool.simple_select_onecol( | |||
table="event_forward_extremities", | |||
keyvalues={"room_id": room_id}, | |||
@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event | |||
from synapse.logging.opentracing import trace | |||
from synapse.storage.controllers import StorageControllers | |||
from synapse.storage.databases.main import DataStore | |||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id | |||
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id | |||
from synapse.types.state import StateFilter | |||
from synapse.util import Clock | |||
@@ -150,12 +150,12 @@ async def filter_events_for_client( | |||
async def filter_event_for_clients_with_state( | |||
store: DataStore, | |||
user_ids: Collection[str], | |||
user_ids: StrCollection, | |||
event: EventBase, | |||
context: EventContext, | |||
is_peeking: bool = False, | |||
filter_send_to_client: bool = True, | |||
) -> Collection[str]: | |||
) -> StrCollection: | |||
""" | |||
Checks to see if an event is visible to the users in the list at the time of | |||
the event. | |||