@@ -0,0 +1 @@ | |||||
Improve type hints. |
@@ -27,9 +27,7 @@ from typing import ( | |||||
Any, | Any, | ||||
Awaitable, | Awaitable, | ||||
Callable, | Callable, | ||||
Collection, | |||||
Dict, | Dict, | ||||
Iterable, | |||||
List, | List, | ||||
NoReturn, | NoReturn, | ||||
Optional, | Optional, | ||||
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_ | |||||
from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( | from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( | ||||
load_legacy_third_party_event_rules, | load_legacy_third_party_event_rules, | ||||
) | ) | ||||
from synapse.types import ISynapseReactor | |||||
from synapse.types import ISynapseReactor, StrCollection | |||||
from synapse.util import SYNAPSE_VERSION | from synapse.util import SYNAPSE_VERSION | ||||
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries | from synapse.util.caches.lrucache import setup_expire_lru_cache_entries | ||||
from synapse.util.daemonize import daemonize_process | from synapse.util.daemonize import daemonize_process | ||||
@@ -278,7 +276,7 @@ def register_start( | |||||
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) | reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper())) | ||||
def listen_metrics(bind_addresses: Iterable[str], port: int) -> None: | |||||
def listen_metrics(bind_addresses: StrCollection, port: int) -> None: | |||||
""" | """ | ||||
Start Prometheus metrics server. | Start Prometheus metrics server. | ||||
""" | """ | ||||
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None: | |||||
def listen_manhole( | def listen_manhole( | ||||
bind_addresses: Collection[str], | |||||
bind_addresses: StrCollection, | |||||
port: int, | port: int, | ||||
manhole_settings: ManholeConfig, | manhole_settings: ManholeConfig, | ||||
manhole_globals: dict, | manhole_globals: dict, | ||||
@@ -339,7 +337,7 @@ def listen_manhole( | |||||
def listen_tcp( | def listen_tcp( | ||||
bind_addresses: Collection[str], | |||||
bind_addresses: StrCollection, | |||||
port: int, | port: int, | ||||
factory: ServerFactory, | factory: ServerFactory, | ||||
reactor: IReactorTCP = reactor, | reactor: IReactorTCP = reactor, | ||||
@@ -448,7 +446,7 @@ def listen_http( | |||||
def listen_ssl( | def listen_ssl( | ||||
bind_addresses: Collection[str], | |||||
bind_addresses: StrCollection, | |||||
port: int, | port: int, | ||||
factory: ServerFactory, | factory: ServerFactory, | ||||
context_factory: IOpenSSLContextFactory, | context_factory: IOpenSSLContextFactory, | ||||
@@ -26,7 +26,6 @@ from textwrap import dedent | |||||
from typing import ( | from typing import ( | ||||
Any, | Any, | ||||
ClassVar, | ClassVar, | ||||
Collection, | |||||
Dict, | Dict, | ||||
Iterable, | Iterable, | ||||
Iterator, | Iterator, | ||||
@@ -384,7 +383,7 @@ class RootConfig: | |||||
config_classes: List[Type[Config]] = [] | config_classes: List[Type[Config]] = [] | ||||
def __init__(self, config_files: Collection[str] = ()): | |||||
def __init__(self, config_files: StrSequence = ()): | |||||
# Capture absolute paths here, so we can reload config after we daemonize. | # Capture absolute paths here, so we can reload config after we daemonize. | ||||
self.config_files = [os.path.abspath(path) for path in config_files] | self.config_files = [os.path.abspath(path) for path in config_files] | ||||
@@ -25,7 +25,6 @@ from typing import ( | |||||
Iterable, | Iterable, | ||||
List, | List, | ||||
Optional, | Optional, | ||||
Sequence, | |||||
Tuple, | Tuple, | ||||
Type, | Type, | ||||
TypeVar, | TypeVar, | ||||
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta): | |||||
def keys(self) -> Iterable[str]: | def keys(self) -> Iterable[str]: | ||||
return self._dict.keys() | return self._dict.keys() | ||||
def prev_event_ids(self) -> Sequence[str]: | |||||
def prev_event_ids(self) -> List[str]: | |||||
"""Returns the list of prev event IDs. The order matches the order | """Returns the list of prev event IDs. The order matches the order | ||||
specified in the event, though there is no meaning to it. | specified in the event, though there is no meaning to it. | ||||
@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase): | |||||
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) | self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) | ||||
return self._event_id | return self._event_id | ||||
def prev_event_ids(self) -> Sequence[str]: | |||||
def prev_event_ids(self) -> List[str]: | |||||
"""Returns the list of prev event IDs. The order matches the order | """Returns the list of prev event IDs. The order matches the order | ||||
specified in the event, though there is no meaning to it. | specified in the event, though there is no meaning to it. | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
import logging | import logging | ||||
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union | |||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union | |||||
import attr | import attr | ||||
from signedjson.types import SigningKey | from signedjson.types import SigningKey | ||||
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event | |||||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict | from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict | ||||
from synapse.state import StateHandler | from synapse.state import StateHandler | ||||
from synapse.storage.databases.main import DataStore | from synapse.storage.databases.main import DataStore | ||||
from synapse.types import EventID, JsonDict | |||||
from synapse.types import EventID, JsonDict, StrCollection | |||||
from synapse.types.state import StateFilter | from synapse.types.state import StateFilter | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from synapse.util.stringutils import random_string | from synapse.util.stringutils import random_string | ||||
@@ -103,7 +103,7 @@ class EventBuilder: | |||||
async def build( | async def build( | ||||
self, | self, | ||||
prev_event_ids: Collection[str], | |||||
prev_event_ids: StrCollection, | |||||
auth_event_ids: Optional[List[str]], | auth_event_ids: Optional[List[str]], | ||||
depth: Optional[int] = None, | depth: Optional[int] = None, | ||||
) -> EventBase: | ) -> EventBase: | ||||
@@ -136,7 +136,7 @@ class EventBuilder: | |||||
format_version = self.room_version.event_format | format_version = self.room_version.event_format | ||||
# The types of auth/prev events changes between event versions. | # The types of auth/prev events changes between event versions. | ||||
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] | |||||
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]] | |||||
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] | auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] | ||||
if format_version == EventFormatVersions.ROOM_V1_V2: | if format_version == EventFormatVersions.ROOM_V1_V2: | ||||
auth_events = await self._store.add_event_hashes(auth_event_ids) | auth_events = await self._store.add_event_hashes(auth_event_ids) | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
import collections.abc | import collections.abc | ||||
from typing import Iterable, List, Type, Union, cast | |||||
from typing import List, Type, Union, cast | |||||
import jsonschema | import jsonschema | ||||
from pydantic import Field, StrictBool, StrictStr | from pydantic import Field, StrictBool, StrictStr | ||||
@@ -36,7 +36,7 @@ from synapse.events.utils import ( | |||||
from synapse.federation.federation_server import server_matches_acl_event | from synapse.federation.federation_server import server_matches_acl_event | ||||
from synapse.http.servlet import validate_json_object | from synapse.http.servlet import validate_json_object | ||||
from synapse.rest.models import RequestBodyModel | from synapse.rest.models import RequestBodyModel | ||||
from synapse.types import EventID, JsonDict, RoomID, UserID | |||||
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID | |||||
class EventValidator: | class EventValidator: | ||||
@@ -225,7 +225,7 @@ class EventValidator: | |||||
self._ensure_state_event(event) | self._ensure_state_event(event) | ||||
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None: | |||||
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None: | |||||
for s in keys: | for s in keys: | ||||
if s not in d: | if s not in d: | ||||
raise SynapseError(400, "'%s' not in content" % (s,)) | raise SynapseError(400, "'%s' not in content" % (s,)) | ||||
@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent | |||||
from synapse.http.types import QueryParams | from synapse.http.types import QueryParams | ||||
from synapse.logging.context import make_deferred_yieldable, run_in_background | from synapse.logging.context import make_deferred_yieldable, run_in_background | ||||
from synapse.logging.opentracing import set_tag, start_active_span, tags | from synapse.logging.opentracing import set_tag, start_active_span, tags | ||||
from synapse.types import ISynapseReactor | |||||
from synapse.types import ISynapseReactor, StrSequence | |||||
from synapse.util import json_decoder | from synapse.util import json_decoder | ||||
from synapse.util.async_helpers import timeout_deferred | from synapse.util.async_helpers import timeout_deferred | ||||
@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu | |||||
# the value actually has to be a List, but List is invariant so we can't specify that | # the value actually has to be a List, but List is invariant so we can't specify that | ||||
# the entries can either be Lists or bytes. | # the entries can either be Lists or bytes. | ||||
RawHeaderValue = Union[ | RawHeaderValue = Union[ | ||||
List[str], | |||||
StrSequence, | |||||
List[bytes], | List[bytes], | ||||
List[Union[str, bytes]], | List[Union[str, bytes]], | ||||
Tuple[str, ...], | |||||
Tuple[bytes, ...], | Tuple[bytes, ...], | ||||
Tuple[Union[str, bytes], ...], | Tuple[Union[str, bytes], ...], | ||||
] | ] | ||||
@@ -18,7 +18,6 @@ import logging | |||||
from http import HTTPStatus | from http import HTTPStatus | ||||
from typing import ( | from typing import ( | ||||
TYPE_CHECKING, | TYPE_CHECKING, | ||||
Iterable, | |||||
List, | List, | ||||
Mapping, | Mapping, | ||||
Optional, | Optional, | ||||
@@ -38,7 +37,7 @@ from twisted.web.server import Request | |||||
from synapse.api.errors import Codes, SynapseError | from synapse.api.errors import Codes, SynapseError | ||||
from synapse.http import redact_uri | from synapse.http import redact_uri | ||||
from synapse.http.server import HttpServer | from synapse.http.server import HttpServer | ||||
from synapse.types import JsonDict, RoomAlias, RoomID | |||||
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection | |||||
from synapse.util import json_decoder | from synapse.util import json_decoder | ||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
@@ -340,7 +339,7 @@ def parse_string( | |||||
name: str, | name: str, | ||||
default: str, | default: str, | ||||
*, | *, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> str: | ) -> str: | ||||
... | ... | ||||
@@ -352,7 +351,7 @@ def parse_string( | |||||
name: str, | name: str, | ||||
*, | *, | ||||
required: Literal[True], | required: Literal[True], | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> str: | ) -> str: | ||||
... | ... | ||||
@@ -365,7 +364,7 @@ def parse_string( | |||||
*, | *, | ||||
default: Optional[str] = None, | default: Optional[str] = None, | ||||
required: bool = False, | required: bool = False, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[str]: | ) -> Optional[str]: | ||||
... | ... | ||||
@@ -376,7 +375,7 @@ def parse_string( | |||||
name: str, | name: str, | ||||
default: Optional[str] = None, | default: Optional[str] = None, | ||||
required: bool = False, | required: bool = False, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[str]: | ) -> Optional[str]: | ||||
""" | """ | ||||
@@ -485,7 +484,7 @@ def parse_enum( | |||||
def _parse_string_value( | def _parse_string_value( | ||||
value: bytes, | value: bytes, | ||||
allowed_values: Optional[Iterable[str]], | |||||
allowed_values: Optional[StrCollection], | |||||
name: str, | name: str, | ||||
encoding: str, | encoding: str, | ||||
) -> str: | ) -> str: | ||||
@@ -511,7 +510,7 @@ def parse_strings_from_args( | |||||
args: Mapping[bytes, Sequence[bytes]], | args: Mapping[bytes, Sequence[bytes]], | ||||
name: str, | name: str, | ||||
*, | *, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[List[str]]: | ) -> Optional[List[str]]: | ||||
... | ... | ||||
@@ -523,7 +522,7 @@ def parse_strings_from_args( | |||||
name: str, | name: str, | ||||
default: List[str], | default: List[str], | ||||
*, | *, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> List[str]: | ) -> List[str]: | ||||
... | ... | ||||
@@ -535,7 +534,7 @@ def parse_strings_from_args( | |||||
name: str, | name: str, | ||||
*, | *, | ||||
required: Literal[True], | required: Literal[True], | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> List[str]: | ) -> List[str]: | ||||
... | ... | ||||
@@ -548,7 +547,7 @@ def parse_strings_from_args( | |||||
default: Optional[List[str]] = None, | default: Optional[List[str]] = None, | ||||
*, | *, | ||||
required: bool = False, | required: bool = False, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[List[str]]: | ) -> Optional[List[str]]: | ||||
... | ... | ||||
@@ -559,7 +558,7 @@ def parse_strings_from_args( | |||||
name: str, | name: str, | ||||
default: Optional[List[str]] = None, | default: Optional[List[str]] = None, | ||||
required: bool = False, | required: bool = False, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[List[str]]: | ) -> Optional[List[str]]: | ||||
""" | """ | ||||
@@ -610,7 +609,7 @@ def parse_string_from_args( | |||||
name: str, | name: str, | ||||
default: Optional[str] = None, | default: Optional[str] = None, | ||||
*, | *, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[str]: | ) -> Optional[str]: | ||||
... | ... | ||||
@@ -623,7 +622,7 @@ def parse_string_from_args( | |||||
default: Optional[str] = None, | default: Optional[str] = None, | ||||
*, | *, | ||||
required: Literal[True], | required: Literal[True], | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> str: | ) -> str: | ||||
... | ... | ||||
@@ -635,7 +634,7 @@ def parse_string_from_args( | |||||
name: str, | name: str, | ||||
default: Optional[str] = None, | default: Optional[str] = None, | ||||
required: bool = False, | required: bool = False, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[str]: | ) -> Optional[str]: | ||||
... | ... | ||||
@@ -646,7 +645,7 @@ def parse_string_from_args( | |||||
name: str, | name: str, | ||||
default: Optional[str] = None, | default: Optional[str] = None, | ||||
required: bool = False, | required: bool = False, | ||||
allowed_values: Optional[Iterable[str]] = None, | |||||
allowed_values: Optional[StrCollection] = None, | |||||
encoding: str = "ascii", | encoding: str = "ascii", | ||||
) -> Optional[str]: | ) -> Optional[str]: | ||||
""" | """ | ||||
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request( | |||||
return validate_json_object(content, model_type) | return validate_json_object(content, model_type) | ||||
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: | |||||
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None: | |||||
absent = [] | absent = [] | ||||
for k in required: | for k in required: | ||||
if k not in body: | if k not in body: | ||||
@@ -25,7 +25,6 @@ from typing import ( | |||||
Iterable, | Iterable, | ||||
Mapping, | Mapping, | ||||
Optional, | Optional, | ||||
Sequence, | |||||
Set, | Set, | ||||
Tuple, | Tuple, | ||||
Type, | Type, | ||||
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401 | |||||
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager | from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager | ||||
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest | from synapse.metrics._twisted_exposition import MetricsResource, generate_latest | ||||
from synapse.metrics._types import Collector | from synapse.metrics._types import Collector | ||||
from synapse.types import StrSequence | |||||
from synapse.util import SYNAPSE_VERSION | from synapse.util import SYNAPSE_VERSION | ||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
@@ -81,7 +81,7 @@ class LaterGauge(Collector): | |||||
name: str | name: str | ||||
desc: str | desc: str | ||||
labels: Optional[Sequence[str]] = attr.ib(hash=False) | |||||
labels: Optional[StrSequence] = attr.ib(hash=False) | |||||
# callback: should either return a value (if there are no labels for this metric), | # callback: should either return a value (if there are no labels for this metric), | ||||
# or dict mapping from a label tuple to a value | # or dict mapping from a label tuple to a value | ||||
caller: Callable[ | caller: Callable[ | ||||
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector): | |||||
self, | self, | ||||
name: str, | name: str, | ||||
desc: str, | desc: str, | ||||
labels: Sequence[str], | |||||
sub_metrics: Sequence[str], | |||||
labels: StrSequence, | |||||
sub_metrics: StrSequence, | |||||
): | ): | ||||
self.name = name | self.name = name | ||||
self.desc = desc | self.desc = desc | ||||
@@ -104,7 +104,7 @@ class _NotifierUserStream: | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
user_id: str, | user_id: str, | ||||
rooms: Collection[str], | |||||
rooms: StrCollection, | |||||
current_token: StreamToken, | current_token: StreamToken, | ||||
time_now_ms: int, | time_now_ms: int, | ||||
): | ): | ||||
@@ -457,7 +457,7 @@ class Notifier: | |||||
stream_key: str, | stream_key: str, | ||||
new_token: Union[int, RoomStreamToken], | new_token: Union[int, RoomStreamToken], | ||||
users: Optional[Collection[Union[str, UserID]]] = None, | users: Optional[Collection[Union[str, UserID]]] = None, | ||||
rooms: Optional[Collection[str]] = None, | |||||
rooms: Optional[StrCollection] = None, | |||||
) -> None: | ) -> None: | ||||
"""Used to inform listeners that something has happened event wise. | """Used to inform listeners that something has happened event wise. | ||||
@@ -529,7 +529,7 @@ class Notifier: | |||||
user_id: str, | user_id: str, | ||||
timeout: int, | timeout: int, | ||||
callback: Callable[[StreamToken, StreamToken], Awaitable[T]], | callback: Callable[[StreamToken, StreamToken], Awaitable[T]], | ||||
room_ids: Optional[Collection[str]] = None, | |||||
room_ids: Optional[StrCollection] = None, | |||||
from_token: StreamToken = StreamToken.START, | from_token: StreamToken = StreamToken.START, | ||||
) -> T: | ) -> T: | ||||
"""Wait until the callback returns a non empty response or the | """Wait until the callback returns a non empty response or the | ||||
@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, | |||||
from synapse.api.errors import InteractiveAuthIncompleteError | from synapse.api.errors import InteractiveAuthIncompleteError | ||||
from synapse.api.urls import CLIENT_API_PREFIX | from synapse.api.urls import CLIENT_API_PREFIX | ||||
from synapse.types import JsonDict | |||||
from synapse.types import JsonDict, StrCollection | |||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
def client_patterns( | def client_patterns( | ||||
path_regex: str, | path_regex: str, | ||||
releases: Iterable[str] = ("r0", "v3"), | |||||
releases: StrCollection = ("r0", "v3"), | |||||
unstable: bool = True, | unstable: bool = True, | ||||
v1: bool = False, | v1: bool = False, | ||||
) -> Iterable[Pattern]: | ) -> Iterable[Pattern]: | ||||
@@ -20,7 +20,6 @@ from typing import ( | |||||
Any, | Any, | ||||
Awaitable, | Awaitable, | ||||
Callable, | Callable, | ||||
Collection, | |||||
DefaultDict, | DefaultDict, | ||||
Dict, | Dict, | ||||
FrozenSet, | FrozenSet, | ||||
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace | |||||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet | from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet | ||||
from synapse.state import v1, v2 | from synapse.state import v1, v2 | ||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour | from synapse.storage.databases.main.events_worker import EventRedactBehaviour | ||||
from synapse.types import StateMap | |||||
from synapse.types import StateMap, StrCollection | |||||
from synapse.types.state import StateFilter | from synapse.types.state import StateFilter | ||||
from synapse.util.async_helpers import Linearizer | from synapse.util.async_helpers import Linearizer | ||||
from synapse.util.caches.expiringcache import ExpiringCache | from synapse.util.caches.expiringcache import ExpiringCache | ||||
@@ -197,7 +196,7 @@ class StateHandler: | |||||
async def compute_state_after_events( | async def compute_state_after_events( | ||||
self, | self, | ||||
room_id: str, | room_id: str, | ||||
event_ids: Collection[str], | |||||
event_ids: StrCollection, | |||||
state_filter: Optional[StateFilter] = None, | state_filter: Optional[StateFilter] = None, | ||||
await_full_state: bool = True, | await_full_state: bool = True, | ||||
) -> StateMap[str]: | ) -> StateMap[str]: | ||||
@@ -231,7 +230,7 @@ class StateHandler: | |||||
return await ret.get_state(self._state_storage_controller, state_filter) | return await ret.get_state(self._state_storage_controller, state_filter) | ||||
async def get_current_user_ids_in_room( | async def get_current_user_ids_in_room( | ||||
self, room_id: str, latest_event_ids: Collection[str] | |||||
self, room_id: str, latest_event_ids: StrCollection | |||||
) -> Set[str]: | ) -> Set[str]: | ||||
""" | """ | ||||
Get the users IDs who are currently in a room. | Get the users IDs who are currently in a room. | ||||
@@ -256,7 +255,7 @@ class StateHandler: | |||||
return await self.store.get_joined_user_ids_from_state(room_id, state) | return await self.store.get_joined_user_ids_from_state(room_id, state) | ||||
async def get_hosts_in_room_at_events( | async def get_hosts_in_room_at_events( | ||||
self, room_id: str, event_ids: Collection[str] | |||||
self, room_id: str, event_ids: StrCollection | |||||
) -> FrozenSet[str]: | ) -> FrozenSet[str]: | ||||
"""Get the hosts that were in a room at the given event ids | """Get the hosts that were in a room at the given event ids | ||||
@@ -470,7 +469,7 @@ class StateHandler: | |||||
@trace | @trace | ||||
@measure_func() | @measure_func() | ||||
async def resolve_state_groups_for_events( | async def resolve_state_groups_for_events( | ||||
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True | |||||
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True | |||||
) -> _StateCacheEntry: | ) -> _StateCacheEntry: | ||||
"""Given a list of event_ids this method fetches the state at each | """Given a list of event_ids this method fetches the state at each | ||||
event, resolves conflicts between them and returns them. | event, resolves conflicts between them and returns them. | ||||
@@ -882,7 +881,7 @@ class StateResolutionStore: | |||||
store: "DataStore" | store: "DataStore" | ||||
def get_events( | def get_events( | ||||
self, event_ids: Collection[str], allow_rejected: bool = False | |||||
self, event_ids: StrCollection, allow_rejected: bool = False | |||||
) -> Awaitable[Dict[str, EventBase]]: | ) -> Awaitable[Dict[str, EventBase]]: | ||||
"""Get events from the database | """Get events from the database | ||||
@@ -17,7 +17,6 @@ import logging | |||||
from typing import ( | from typing import ( | ||||
Awaitable, | Awaitable, | ||||
Callable, | Callable, | ||||
Collection, | |||||
Dict, | Dict, | ||||
Iterable, | Iterable, | ||||
List, | List, | ||||
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes | |||||
from synapse.api.errors import AuthError | from synapse.api.errors import AuthError | ||||
from synapse.api.room_versions import RoomVersion | from synapse.api.room_versions import RoomVersion | ||||
from synapse.events import EventBase | from synapse.events import EventBase | ||||
from synapse.types import MutableStateMap, StateMap | |||||
from synapse.types import MutableStateMap, StateMap, StrCollection | |||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
@@ -45,7 +44,7 @@ async def resolve_events_with_store( | |||||
room_version: RoomVersion, | room_version: RoomVersion, | ||||
state_sets: Sequence[StateMap[str]], | state_sets: Sequence[StateMap[str]], | ||||
event_map: Optional[Dict[str, EventBase]], | event_map: Optional[Dict[str, EventBase]], | ||||
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]], | |||||
state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]], | |||||
) -> StateMap[str]: | ) -> StateMap[str]: | ||||
""" | """ | ||||
Args: | Args: | ||||
@@ -19,7 +19,6 @@ from typing import ( | |||||
Any, | Any, | ||||
Awaitable, | Awaitable, | ||||
Callable, | Callable, | ||||
Collection, | |||||
Dict, | Dict, | ||||
Generator, | Generator, | ||||
Iterable, | Iterable, | ||||
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes | |||||
from synapse.api.errors import AuthError | from synapse.api.errors import AuthError | ||||
from synapse.api.room_versions import RoomVersion | from synapse.api.room_versions import RoomVersion | ||||
from synapse.events import EventBase | from synapse.events import EventBase | ||||
from synapse.types import MutableStateMap, StateMap | |||||
from synapse.types import MutableStateMap, StateMap, StrCollection | |||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol): | |||||
# This is usually synapse.state.StateResolutionStore, but it's replaced with a | # This is usually synapse.state.StateResolutionStore, but it's replaced with a | ||||
# TestStateResolutionStore in tests. | # TestStateResolutionStore in tests. | ||||
def get_events( | def get_events( | ||||
self, event_ids: Collection[str], allow_rejected: bool = False | |||||
self, event_ids: StrCollection, allow_rejected: bool = False | |||||
) -> Awaitable[Dict[str, EventBase]]: | ) -> Awaitable[Dict[str, EventBase]]: | ||||
... | ... | ||||
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference( | |||||
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) | union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) | ||||
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) | intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) | ||||
auth_difference_unpersisted_part: Collection[str] = union - intersection | |||||
auth_difference_unpersisted_part: StrCollection = union - intersection | |||||
else: | else: | ||||
auth_difference_unpersisted_part = () | auth_difference_unpersisted_part = () | ||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets] | state_sets_ids = [set(state_set.values()) for state_set in state_sets] | ||||
@@ -47,7 +47,7 @@ from synapse.storage.database import ( | |||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore | from synapse.storage.databases.main.events_worker import EventsWorkerStore | ||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore | from synapse.storage.databases.main.signatures import SignatureWorkerStore | ||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine | from synapse.storage.engines import PostgresEngine, Sqlite3Engine | ||||
from synapse.types import JsonDict, StrCollection | |||||
from synapse.types import JsonDict, StrCollection, StrSequence | |||||
from synapse.util import json_encoder | from synapse.util import json_encoder | ||||
from synapse.util.caches.descriptors import cached | from synapse.util.caches.descriptors import cached | ||||
from synapse.util.caches.lrucache import LruCache | from synapse.util.caches.lrucache import LruCache | ||||
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas | |||||
) | ) | ||||
@cached(max_entries=5000, iterable=True) | @cached(max_entries=5000, iterable=True) | ||||
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: | |||||
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence: | |||||
return await self.db_pool.simple_select_onecol( | return await self.db_pool.simple_select_onecol( | ||||
table="event_forward_extremities", | table="event_forward_extremities", | ||||
keyvalues={"room_id": room_id}, | keyvalues={"room_id": room_id}, | ||||
@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event | |||||
from synapse.logging.opentracing import trace | from synapse.logging.opentracing import trace | ||||
from synapse.storage.controllers import StorageControllers | from synapse.storage.controllers import StorageControllers | ||||
from synapse.storage.databases.main import DataStore | from synapse.storage.databases.main import DataStore | ||||
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id | |||||
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id | |||||
from synapse.types.state import StateFilter | from synapse.types.state import StateFilter | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
@@ -150,12 +150,12 @@ async def filter_events_for_client( | |||||
async def filter_event_for_clients_with_state( | async def filter_event_for_clients_with_state( | ||||
store: DataStore, | store: DataStore, | ||||
user_ids: Collection[str], | |||||
user_ids: StrCollection, | |||||
event: EventBase, | event: EventBase, | ||||
context: EventContext, | context: EventContext, | ||||
is_peeking: bool = False, | is_peeking: bool = False, | ||||
filter_send_to_client: bool = True, | filter_send_to_client: bool = True, | ||||
) -> Collection[str]: | |||||
) -> StrCollection: | |||||
""" | """ | ||||
Checks to see if an event is visible to the users in the list at the time of | Checks to see if an event is visible to the users in the list at the time of | ||||
the event. | the event. | ||||