Browse Source

Use StrCollection in additional places. (#16301)

tags/v1.93.0rc1
Patrick Cloke 7 months ago
committed by GitHub
parent
commit
d38d0dffc9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 59 additions and 67 deletions
  1. +1
    -0
      changelog.d/16301.misc
  2. +5
    -7
      synapse/app/_base.py
  3. +1
    -2
      synapse/config/_base.py
  4. +2
    -3
      synapse/events/__init__.py
  5. +4
    -4
      synapse/events/builder.py
  6. +3
    -3
      synapse/events/validator.py
  7. +2
    -3
      synapse/http/client.py
  8. +16
    -17
      synapse/http/servlet.py
  9. +4
    -4
      synapse/metrics/__init__.py
  10. +3
    -3
      synapse/notifier.py
  11. +2
    -2
      synapse/rest/client/_base.py
  12. +6
    -7
      synapse/state/__init__.py
  13. +2
    -3
      synapse/state/v1.py
  14. +3
    -4
      synapse/state/v2.py
  15. +2
    -2
      synapse/storage/databases/main/event_federation.py
  16. +3
    -3
      synapse/visibility.py

+ 1
- 0
changelog.d/16301.misc View File

@@ -0,0 +1 @@
Improve type hints.

+ 5
- 7
synapse/app/_base.py View File

@@ -27,9 +27,7 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
NoReturn,
Optional,
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
from synapse.types import ISynapseReactor
from synapse.types import ISynapseReactor, StrCollection
from synapse.util import SYNAPSE_VERSION
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
@@ -278,7 +276,7 @@ def register_start(
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))


def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
"""
Start Prometheus metrics server.
"""
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:


def listen_manhole(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
@@ -339,7 +337,7 @@ def listen_manhole(


def listen_tcp(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
reactor: IReactorTCP = reactor,
@@ -448,7 +446,7 @@ def listen_http(


def listen_ssl(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
context_factory: IOpenSSLContextFactory,


+ 1
- 2
synapse/config/_base.py View File

@@ -26,7 +26,6 @@ from textwrap import dedent
from typing import (
Any,
ClassVar,
Collection,
Dict,
Iterable,
Iterator,
@@ -384,7 +383,7 @@ class RootConfig:

config_classes: List[Type[Config]] = []

def __init__(self, config_files: Collection[str] = ()):
def __init__(self, config_files: StrSequence = ()):
# Capture absolute paths here, so we can reload config after we daemonize.
self.config_files = [os.path.abspath(path) for path in config_files]



+ 2
- 3
synapse/events/__init__.py View File

@@ -25,7 +25,6 @@ from typing import (
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
def keys(self) -> Iterable[str]:
return self._dict.keys()

def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.

@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id

def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.



+ 4
- 4
synapse/events/builder.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import attr
from signedjson.types import SigningKey
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import EventID, JsonDict
from synapse.types import EventID, JsonDict, StrCollection
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -103,7 +103,7 @@ class EventBuilder:

async def build(
self,
prev_event_ids: Collection[str],
prev_event_ids: StrCollection,
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:

format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)


+ 3
- 3
synapse/events/validator.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from typing import Iterable, List, Type, Union, cast
from typing import List, Type, Union, cast

import jsonschema
from pydantic import Field, StrictBool, StrictStr
@@ -36,7 +36,7 @@ from synapse.events.utils import (
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
from synapse.types import EventID, JsonDict, RoomID, UserID
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID


class EventValidator:
@@ -225,7 +225,7 @@ class EventValidator:

self._ensure_state_event(event)

def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))


+ 2
- 3
synapse/http/client.py View File

@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
from synapse.types import ISynapseReactor, StrSequence
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred

@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
RawHeaderValue = Union[
List[str],
StrSequence,
List[bytes],
List[Union[str, bytes]],
Tuple[str, ...],
Tuple[bytes, ...],
Tuple[Union[str, bytes], ...],
]


+ 16
- 17
synapse/http/servlet.py View File

@@ -18,7 +18,6 @@ import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Iterable,
List,
Mapping,
Optional,
@@ -38,7 +37,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
from synapse.util import json_decoder

if TYPE_CHECKING:
@@ -340,7 +339,7 @@ def parse_string(
name: str,
default: str,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -352,7 +351,7 @@ def parse_string(
name: str,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -365,7 +364,7 @@ def parse_string(
*,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -376,7 +375,7 @@ def parse_string(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -485,7 +484,7 @@ def parse_enum(

def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
allowed_values: Optional[StrCollection],
name: str,
encoding: str,
) -> str:
@@ -511,7 +510,7 @@ def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -523,7 +522,7 @@ def parse_strings_from_args(
name: str,
default: List[str],
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -535,7 +534,7 @@ def parse_strings_from_args(
name: str,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -548,7 +547,7 @@ def parse_strings_from_args(
default: Optional[List[str]] = None,
*,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -559,7 +558,7 @@ def parse_strings_from_args(
name: str,
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
"""
@@ -610,7 +609,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -623,7 +622,7 @@ def parse_string_from_args(
default: Optional[str] = None,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -635,7 +634,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -646,7 +645,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
return validate_json_object(content, model_type)


def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
absent = []
for k in required:
if k not in body:


+ 4
- 4
synapse/metrics/__init__.py View File

@@ -25,7 +25,6 @@ from typing import (
Iterable,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
from synapse.metrics._types import Collector
from synapse.types import StrSequence
from synapse.util import SYNAPSE_VERSION

logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ class LaterGauge(Collector):

name: str
desc: str
labels: Optional[Sequence[str]] = attr.ib(hash=False)
labels: Optional[StrSequence] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
self,
name: str,
desc: str,
labels: Sequence[str],
sub_metrics: Sequence[str],
labels: StrSequence,
sub_metrics: StrSequence,
):
self.name = name
self.desc = desc


+ 3
- 3
synapse/notifier.py View File

@@ -104,7 +104,7 @@ class _NotifierUserStream:
def __init__(
self,
user_id: str,
rooms: Collection[str],
rooms: StrCollection,
current_token: StreamToken,
time_now_ms: int,
):
@@ -457,7 +457,7 @@ class Notifier:
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
rooms: Optional[Collection[str]] = None,
rooms: Optional[StrCollection] = None,
) -> None:
"""Used to inform listeners that something has happened event wise.

@@ -529,7 +529,7 @@ class Notifier:
user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
room_ids: Optional[Collection[str]] = None,
room_ids: Optional[StrCollection] = None,
from_token: StreamToken = StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the


+ 2
- 2
synapse/rest/client/_base.py View File

@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,

from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.types import JsonDict
from synapse.types import JsonDict, StrCollection

logger = logging.getLogger(__name__)


def client_patterns(
path_regex: str,
releases: Iterable[str] = ("r0", "v3"),
releases: StrCollection = ("r0", "v3"),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:


+ 6
- 7
synapse/state/__init__.py View File

@@ -20,7 +20,6 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
DefaultDict,
Dict,
FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
from synapse.types import StateMap, StrCollection
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
async def compute_state_after_events(
self,
room_id: str,
event_ids: Collection[str],
event_ids: StrCollection,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)

async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: Collection[str]
self, room_id: str, latest_event_ids: StrCollection
) -> Set[str]:
"""
Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
return await self.store.get_joined_user_ids_from_state(room_id, state)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
self, room_id: str, event_ids: StrCollection
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids

@@ -470,7 +469,7 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
store: "DataStore"

def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database



+ 2
- 3
synapse/state/v1.py View File

@@ -17,7 +17,6 @@ import logging
from typing import (
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.types import MutableStateMap, StateMap, StrCollection

logger = logging.getLogger(__name__)

@@ -45,7 +44,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:


+ 3
- 4
synapse/state/v2.py View File

@@ -19,7 +19,6 @@ from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Generator,
Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.types import MutableStateMap, StateMap, StrCollection

logger = logging.getLogger(__name__)

@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
# TestStateResolutionStore in tests.
def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
...

@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

auth_difference_unpersisted_part: Collection[str] = union - intersection
auth_difference_unpersisted_part: StrCollection = union - intersection
else:
auth_difference_unpersisted_part = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]


+ 2
- 2
synapse/storage/databases/main/event_federation.py View File

@@ -47,7 +47,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, StrCollection
from synapse.types import JsonDict, StrCollection, StrSequence
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)

@cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},


+ 3
- 3
synapse/visibility.py View File

@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import Clock

@@ -150,12 +150,12 @@ async def filter_events_for_client(

async def filter_event_for_clients_with_state(
store: DataStore,
user_ids: Collection[str],
user_ids: StrCollection,
event: EventBase,
context: EventContext,
is_peeking: bool = False,
filter_send_to_client: bool = True,
) -> Collection[str]:
) -> StrCollection:
"""
Checks to see if an event is visible to the users in the list at the time of
the event.


Loading…
Cancel
Save