Except `synapse/events/__init__.py`, which will be done in a follow-up.tags/v1.46.0rc1
@@ -0,0 +1 @@ | |||
Add type hints to `synapse.events`. |
@@ -22,8 +22,11 @@ files = | |||
synapse/crypto, | |||
synapse/event_auth.py, | |||
synapse/events/builder.py, | |||
synapse/events/presence_router.py, | |||
synapse/events/snapshot.py, | |||
synapse/events/spamcheck.py, | |||
synapse/events/third_party_rules.py, | |||
synapse/events/utils.py, | |||
synapse/events/validator.py, | |||
synapse/federation, | |||
synapse/groups, | |||
@@ -96,6 +99,9 @@ files = | |||
tests/util/test_itertools.py, | |||
tests/util/test_stream_change_cache.py | |||
[mypy-synapse.events.*] | |||
disallow_untyped_defs = True | |||
[mypy-synapse.handlers.*] | |||
disallow_untyped_defs = True | |||
@@ -90,13 +90,13 @@ class EventBuilder: | |||
) | |||
@property | |||
def state_key(self): | |||
def state_key(self) -> str: | |||
if self._state_key is not None: | |||
return self._state_key | |||
raise AttributeError("state_key") | |||
def is_state(self): | |||
def is_state(self) -> bool: | |||
return self._state_key is not None | |||
async def build( | |||
@@ -14,6 +14,7 @@ | |||
import logging | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Awaitable, | |||
Callable, | |||
Dict, | |||
@@ -33,14 +34,13 @@ if TYPE_CHECKING: | |||
GET_USERS_FOR_STATES_CALLBACK = Callable[ | |||
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]] | |||
] | |||
GET_INTERESTED_USERS_CALLBACK = Callable[ | |||
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]] | |||
] | |||
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS. | |||
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]] | |||
logger = logging.getLogger(__name__) | |||
def load_legacy_presence_router(hs: "HomeServer"): | |||
def load_legacy_presence_router(hs: "HomeServer") -> None: | |||
"""Wrapper that loads a presence router module configured using the old | |||
configuration, and registers the hooks they implement. | |||
""" | |||
@@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"): | |||
if f is None: | |||
return None | |||
def run(*args, **kwargs): | |||
# mypy doesn't do well across function boundaries so we need to tell it | |||
# f is definitely not None. | |||
def run(*args: Any, **kwargs: Any) -> Awaitable: | |||
# Assertion required because mypy can't prove we won't change `f` | |||
# back to `None`. See | |||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions | |||
assert f is not None | |||
return maybe_awaitable(f(*args, **kwargs)) | |||
@@ -104,7 +105,7 @@ class PresenceRouter: | |||
self, | |||
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, | |||
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, | |||
): | |||
) -> None: | |||
# PresenceRouter modules are required to implement both of these methods | |||
# or neither of them as they are assumed to act in a complementary manner | |||
paired_methods = [get_users_for_states, get_interested_users] | |||
@@ -142,7 +143,7 @@ class PresenceRouter: | |||
# Don't include any extra destinations for presence updates | |||
return {} | |||
users_for_states = {} | |||
users_for_states: Dict[str, Set[UserPresenceState]] = {} | |||
# run all the callbacks for get_users_for_states and combine the results | |||
for callback in self._get_users_for_states_callbacks: | |||
try: | |||
@@ -171,7 +172,7 @@ class PresenceRouter: | |||
return users_for_states | |||
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: | |||
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: | |||
""" | |||
Retrieve a list of users that `user_id` is interested in receiving the | |||
presence of. This will be in addition to those they share a room with. | |||
@@ -11,17 +11,20 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Optional, Union | |||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union | |||
import attr | |||
from frozendict import frozendict | |||
from twisted.internet.defer import Deferred | |||
from synapse.appservice import ApplicationService | |||
from synapse.events import EventBase | |||
from synapse.logging.context import make_deferred_yieldable, run_in_background | |||
from synapse.types import StateMap | |||
from synapse.types import JsonDict, StateMap | |||
if TYPE_CHECKING: | |||
from synapse.storage import Storage | |||
from synapse.storage.databases.main import DataStore | |||
@@ -112,13 +115,13 @@ class EventContext: | |||
@staticmethod | |||
def with_state( | |||
state_group, | |||
state_group_before_event, | |||
current_state_ids, | |||
prev_state_ids, | |||
prev_group=None, | |||
delta_ids=None, | |||
): | |||
state_group: Optional[int], | |||
state_group_before_event: Optional[int], | |||
current_state_ids: Optional[StateMap[str]], | |||
prev_state_ids: Optional[StateMap[str]], | |||
prev_group: Optional[int] = None, | |||
delta_ids: Optional[StateMap[str]] = None, | |||
) -> "EventContext": | |||
return EventContext( | |||
current_state_ids=current_state_ids, | |||
prev_state_ids=prev_state_ids, | |||
@@ -129,22 +132,22 @@ class EventContext: | |||
) | |||
@staticmethod | |||
def for_outlier(): | |||
def for_outlier() -> "EventContext": | |||
"""Return an EventContext instance suitable for persisting an outlier event""" | |||
return EventContext( | |||
current_state_ids={}, | |||
prev_state_ids={}, | |||
) | |||
async def serialize(self, event: EventBase, store: "DataStore") -> dict: | |||
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: | |||
"""Converts self to a type that can be serialized as JSON, and then | |||
deserialized by `deserialize` | |||
Args: | |||
event (FrozenEvent): The event that this context relates to | |||
event: The event that this context relates to | |||
Returns: | |||
dict | |||
The serialized event. | |||
""" | |||
# We don't serialize the full state dicts, instead they get pulled out | |||
@@ -170,17 +173,16 @@ class EventContext: | |||
} | |||
@staticmethod | |||
def deserialize(storage, input): | |||
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": | |||
"""Converts a dict that was produced by `serialize` back into a | |||
EventContext. | |||
Args: | |||
storage (Storage): Used to convert AS ID to AS object and fetch | |||
state. | |||
input (dict): A dict produced by `serialize` | |||
storage: Used to convert AS ID to AS object and fetch state. | |||
input: A dict produced by `serialize` | |||
Returns: | |||
EventContext | |||
The event context. | |||
""" | |||
context = _AsyncEventContextImpl( | |||
# We use the state_group and prev_state_id stuff to pull the | |||
@@ -241,22 +243,25 @@ class EventContext: | |||
await self._ensure_fetched() | |||
return self._current_state_ids | |||
async def get_prev_state_ids(self): | |||
async def get_prev_state_ids(self) -> StateMap[str]: | |||
""" | |||
Gets the room state map, excluding this event. | |||
For a non-state event, this will be the same as get_current_state_ids(). | |||
Returns: | |||
dict[(str, str), str]|None: Returns None if state_group | |||
is None, which happens when the associated event is an outlier. | |||
Maps a (type, state_key) to the event ID of the state event matching | |||
this tuple. | |||
Returns {} if state_group is None, which happens when the associated | |||
event is an outlier. | |||
Maps a (type, state_key) to the event ID of the state event matching | |||
this tuple. | |||
""" | |||
await self._ensure_fetched() | |||
# There *should* be previous state IDs now. | |||
assert self._prev_state_ids is not None | |||
return self._prev_state_ids | |||
def get_cached_current_state_ids(self): | |||
def get_cached_current_state_ids(self) -> Optional[StateMap[str]]: | |||
"""Gets the current state IDs if we have them already cached. | |||
It is an error to access this for a rejected event, since rejected state should | |||
@@ -264,16 +269,17 @@ class EventContext: | |||
``rejected`` is set. | |||
Returns: | |||
dict[(str, str), str]|None: Returns None if we haven't cached the | |||
state or if state_group is None, which happens when the associated | |||
event is an outlier. | |||
Returns None if we haven't cached the state or if state_group is None | |||
(which happens when the associated event is an outlier). | |||
Otherwise, returns the the current state IDs. | |||
""" | |||
if self.rejected: | |||
raise RuntimeError("Attempt to access state_ids of rejected event") | |||
return self._current_state_ids | |||
async def _ensure_fetched(self): | |||
async def _ensure_fetched(self) -> None: | |||
return None | |||
@@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext): | |||
Attributes: | |||
_storage (Storage) | |||
_storage | |||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have | |||
been calculated. None if we haven't started calculating yet | |||
_fetching_state_deferred: Resolves when *_state_ids have been calculated. | |||
None if we haven't started calculating yet | |||
_event_type (str): The type of the event the context is associated with. | |||
_event_type: The type of the event the context is associated with. | |||
_event_state_key (str): The state_key of the event the context is | |||
associated with. | |||
_event_state_key: The state_key of the event the context is associated with. | |||
_prev_state_id (str|None): If the event associated with the context is | |||
a state event, then `_prev_state_id` is the event_id of the state | |||
that was replaced. | |||
_prev_state_id: If the event associated with the context is a state event, | |||
then `_prev_state_id` is the event_id of the state that was replaced. | |||
""" | |||
# This needs to have a default as we're inheriting | |||
_storage = attr.ib(default=None) | |||
_prev_state_id = attr.ib(default=None) | |||
_event_type = attr.ib(default=None) | |||
_event_state_key = attr.ib(default=None) | |||
_fetching_state_deferred = attr.ib(default=None) | |||
_storage: "Storage" = attr.ib(default=None) | |||
_prev_state_id: Optional[str] = attr.ib(default=None) | |||
_event_type: str = attr.ib(default=None) | |||
_event_state_key: Optional[str] = attr.ib(default=None) | |||
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None) | |||
async def _ensure_fetched(self): | |||
async def _ensure_fetched(self) -> None: | |||
if not self._fetching_state_deferred: | |||
self._fetching_state_deferred = run_in_background(self._fill_out_state) | |||
return await make_deferred_yieldable(self._fetching_state_deferred) | |||
await make_deferred_yieldable(self._fetching_state_deferred) | |||
async def _fill_out_state(self): | |||
async def _fill_out_state(self) -> None: | |||
"""Called to populate the _current_state_ids and _prev_state_ids | |||
attributes by loading from the database. | |||
""" | |||
if self.state_group is None: | |||
return | |||
self._current_state_ids = await self._storage.state.get_state_ids_for_group( | |||
current_state_ids = await self._storage.state.get_state_ids_for_group( | |||
self.state_group | |||
) | |||
# Set this separately so mypy knows current_state_ids is not None. | |||
self._current_state_ids = current_state_ids | |||
if self._event_state_key is not None: | |||
self._prev_state_ids = dict(self._current_state_ids) | |||
self._prev_state_ids = dict(current_state_ids) | |||
key = (self._event_type, self._event_state_key) | |||
if self._prev_state_id: | |||
@@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext): | |||
else: | |||
self._prev_state_ids.pop(key, None) | |||
else: | |||
self._prev_state_ids = self._current_state_ids | |||
self._prev_state_ids = current_state_ids | |||
def _encode_state_dict(state_dict): | |||
def _encode_state_dict( | |||
state_dict: Optional[StateMap[str]], | |||
) -> Optional[List[Tuple[str, str, str]]]: | |||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in | |||
JSON we need to convert them to a form that can. | |||
""" | |||
@@ -345,7 +353,9 @@ def _encode_state_dict(state_dict): | |||
return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()] | |||
def _decode_state_dict(input): | |||
def _decode_state_dict( | |||
input: Optional[List[Tuple[str, str, str]]] | |||
) -> Optional[StateMap[str]]: | |||
"""Decodes a state dict encoded using `_encode_state_dict` above""" | |||
if input is None: | |||
return None | |||
@@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[ | |||
] | |||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): | |||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: | |||
"""Wrapper that loads spam checkers configured using the old configuration, and | |||
registers the spam checker hooks they implement. | |||
""" | |||
@@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): | |||
request_info: Collection[Tuple[str, str]], | |||
auth_provider_id: Optional[str], | |||
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]: | |||
# We've already made sure f is not None above, but mypy doesn't | |||
# do well across function boundaries so we need to tell it f is | |||
# definitely not None. | |||
# Assertion required because mypy can't prove we won't | |||
# change `f` back to `None`. See | |||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions | |||
assert f is not None | |||
return f( | |||
@@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): | |||
"Bad signature for callback check_registration_for_spam", | |||
) | |||
def run(*args, **kwargs): | |||
# mypy doesn't do well across function boundaries so we need to tell it | |||
# wrapped_func is definitely not None. | |||
def run(*args: Any, **kwargs: Any) -> Awaitable: | |||
# Assertion required because mypy can't prove we won't change `f` | |||
# back to `None`. See | |||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions | |||
assert wrapped_func is not None | |||
return maybe_awaitable(wrapped_func(*args, **kwargs)) | |||
@@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): | |||
class SpamChecker: | |||
def __init__(self): | |||
def __init__(self) -> None: | |||
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] | |||
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] | |||
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] | |||
@@ -209,7 +210,7 @@ class SpamChecker: | |||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK | |||
] = None, | |||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, | |||
): | |||
) -> None: | |||
"""Register callbacks from module for each hook.""" | |||
if check_event_for_spam is not None: | |||
self._check_event_for_spam_callbacks.append(check_event_for_spam) | |||
@@ -275,7 +276,9 @@ class SpamChecker: | |||
return False | |||
async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool): | |||
async def user_may_join_room( | |||
self, user_id: str, room_id: str, is_invited: bool | |||
) -> bool: | |||
"""Checks if a given users is allowed to join a room. | |||
Not called when a user creates a room. | |||
@@ -285,7 +288,7 @@ class SpamChecker: | |||
is_invited: Whether the user is invited into the room | |||
Returns: | |||
bool: Whether the user may join the room | |||
Whether the user may join the room | |||
""" | |||
for callback in self._user_may_join_room_callbacks: | |||
if await callback(user_id, room_id, is_invited) is False: | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple | |||
from synapse.api.errors import SynapseError | |||
from synapse.events import EventBase | |||
@@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ | |||
] | |||
def load_legacy_third_party_event_rules(hs: "HomeServer"): | |||
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: | |||
"""Wrapper that loads a third party event rules module configured using the old | |||
configuration, and registers the hooks they implement. | |||
""" | |||
@@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): | |||
event: EventBase, | |||
state_events: StateMap[EventBase], | |||
) -> Tuple[bool, Optional[dict]]: | |||
# We've already made sure f is not None above, but mypy doesn't do well | |||
# across function boundaries so we need to tell it f is definitely not | |||
# None. | |||
# Assertion required because mypy can't prove we won't change | |||
# `f` back to `None`. See | |||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions | |||
assert f is not None | |||
res = await f(event, state_events) | |||
@@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): | |||
async def wrap_on_create_room( | |||
requester: Requester, config: dict, is_requester_admin: bool | |||
) -> None: | |||
# We've already made sure f is not None above, but mypy doesn't do well | |||
# across function boundaries so we need to tell it f is definitely not | |||
# None. | |||
# Assertion required because mypy can't prove we won't change | |||
# `f` back to `None`. See | |||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions | |||
assert f is not None | |||
res = await f(requester, config, is_requester_admin) | |||
@@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): | |||
return wrap_on_create_room | |||
def run(*args, **kwargs): | |||
# mypy doesn't do well across function boundaries so we need to tell it | |||
# f is definitely not None. | |||
def run(*args: Any, **kwargs: Any) -> Awaitable: | |||
# Assertion required because mypy can't prove we won't change `f` | |||
# back to `None`. See | |||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions | |||
assert f is not None | |||
return maybe_awaitable(f(*args, **kwargs)) | |||
@@ -162,7 +163,7 @@ class ThirdPartyEventRules: | |||
check_visibility_can_be_modified: Optional[ | |||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK | |||
] = None, | |||
): | |||
) -> None: | |||
"""Register callbacks from modules for each hook.""" | |||
if check_event_allowed is not None: | |||
self._check_event_allowed_callbacks.append(check_event_allowed) | |||
@@ -13,18 +13,32 @@ | |||
# limitations under the License. | |||
import collections.abc | |||
import re | |||
from typing import Any, Mapping, Union | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Callable, | |||
Dict, | |||
Iterable, | |||
List, | |||
Mapping, | |||
Optional, | |||
Union, | |||
) | |||
from frozendict import frozendict | |||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.api.room_versions import RoomVersion | |||
from synapse.types import JsonDict | |||
from synapse.util.async_helpers import yieldable_gather_results | |||
from synapse.util.frozenutils import unfreeze | |||
from . import EventBase | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' | |||
# (?<!stuff) matches if the current position in the string is not preceded | |||
# by a match for 'stuff'. | |||
@@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase: | |||
return pruned_event | |||
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: | |||
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict: | |||
"""Redacts the event_dict in the same way as `prune_event`, except it | |||
operates on dicts rather than event objects | |||
@@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: | |||
new_content = {} | |||
def add_fields(*fields): | |||
def add_fields(*fields: str) -> None: | |||
for field in fields: | |||
if field in event_dict["content"]: | |||
new_content[field] = event_dict["content"][field] | |||
@@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: | |||
allowed_fields["content"] = new_content | |||
unsigned = {} | |||
unsigned: JsonDict = {} | |||
allowed_fields["unsigned"] = unsigned | |||
event_unsigned = event_dict.get("unsigned", {}) | |||
@@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: | |||
return allowed_fields | |||
def _copy_field(src, dst, field): | |||
def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None: | |||
"""Copy the field in 'src' to 'dst'. | |||
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] | |||
then dst={"foo":{"bar":5}}. | |||
Args: | |||
src(dict): The dict to read from. | |||
dst(dict): The dict to modify. | |||
field(list<str>): List of keys to drill down to in 'src'. | |||
src: The dict to read from. | |||
dst: The dict to modify. | |||
field: List of keys to drill down to in 'src'. | |||
""" | |||
if len(field) == 0: # this should be impossible | |||
return | |||
@@ -205,7 +219,7 @@ def _copy_field(src, dst, field): | |||
sub_out_dict[key_to_move] = sub_dict[key_to_move] | |||
def only_fields(dictionary, fields): | |||
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict: | |||
"""Return a new dict with only the fields in 'dictionary' which are present | |||
in 'fields'. | |||
@@ -215,11 +229,11 @@ def only_fields(dictionary, fields): | |||
A literal '.' character in a field name may be escaped using a '\'. | |||
Args: | |||
dictionary(dict): The dictionary to read from. | |||
fields(list<str>): A list of fields to copy over. Only shallow refs are | |||
dictionary: The dictionary to read from. | |||
fields: A list of fields to copy over. Only shallow refs are | |||
taken. | |||
Returns: | |||
dict: A new dictionary with only the given fields. If fields was empty, | |||
A new dictionary with only the given fields. If fields was empty, | |||
the same dictionary is returned. | |||
""" | |||
if len(fields) == 0: | |||
@@ -235,17 +249,17 @@ def only_fields(dictionary, fields): | |||
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields | |||
] | |||
output = {} | |||
output: JsonDict = {} | |||
for field_array in split_fields: | |||
_copy_field(dictionary, output, field_array) | |||
return output | |||
def format_event_raw(d): | |||
def format_event_raw(d: JsonDict) -> JsonDict: | |||
return d | |||
def format_event_for_client_v1(d): | |||
def format_event_for_client_v1(d: JsonDict) -> JsonDict: | |||
d = format_event_for_client_v2(d) | |||
sender = d.get("sender") | |||
@@ -267,7 +281,7 @@ def format_event_for_client_v1(d): | |||
return d | |||
def format_event_for_client_v2(d): | |||
def format_event_for_client_v2(d: JsonDict) -> JsonDict: | |||
drop_keys = ( | |||
"auth_events", | |||
"prev_events", | |||
@@ -282,37 +296,37 @@ def format_event_for_client_v2(d): | |||
return d | |||
def format_event_for_client_v2_without_room_id(d): | |||
def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: | |||
d = format_event_for_client_v2(d) | |||
d.pop("room_id", None) | |||
return d | |||
def serialize_event( | |||
e, | |||
time_now_ms, | |||
as_client_event=True, | |||
event_format=format_event_for_client_v1, | |||
token_id=None, | |||
only_event_fields=None, | |||
include_stripped_room_state=False, | |||
): | |||
e: Union[JsonDict, EventBase], | |||
time_now_ms: int, | |||
as_client_event: bool = True, | |||
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, | |||
token_id: Optional[str] = None, | |||
only_event_fields: Optional[List[str]] = None, | |||
include_stripped_room_state: bool = False, | |||
) -> JsonDict: | |||
"""Serialize event for clients | |||
Args: | |||
e (EventBase) | |||
time_now_ms (int) | |||
as_client_event (bool) | |||
e | |||
time_now_ms | |||
as_client_event | |||
event_format | |||
token_id | |||
only_event_fields | |||
include_stripped_room_state (bool): Some events can have stripped room state | |||
include_stripped_room_state: Some events can have stripped room state | |||
stored in the `unsigned` field. This is required for invite and knock | |||
functionality. If this option is False, that state will be removed from the | |||
event before it is returned. Otherwise, it will be kept. | |||
Returns: | |||
dict | |||
The serialized event dictionary. | |||
""" | |||
# FIXME(erikj): To handle the case of presence events and the like | |||
@@ -369,25 +383,29 @@ class EventClientSerializer: | |||
clients. | |||
""" | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastore() | |||
self.experimental_msc1849_support_enabled = ( | |||
hs.config.server.experimental_msc1849_support_enabled | |||
) | |||
async def serialize_event( | |||
self, event, time_now, bundle_aggregations=True, **kwargs | |||
): | |||
self, | |||
event: Union[JsonDict, EventBase], | |||
time_now: int, | |||
bundle_aggregations: bool = True, | |||
**kwargs: Any, | |||
) -> JsonDict: | |||
"""Serializes a single event. | |||
Args: | |||
event (EventBase) | |||
time_now (int): The current time in milliseconds | |||
bundle_aggregations (bool): Whether to bundle in related events | |||
event | |||
time_now: The current time in milliseconds | |||
bundle_aggregations: Whether to bundle in related events | |||
**kwargs: Arguments to pass to `serialize_event` | |||
Returns: | |||
dict: The serialized event | |||
The serialized event | |||
""" | |||
# To handle the case of presence events and the like | |||
if not isinstance(event, EventBase): | |||
@@ -448,25 +466,27 @@ class EventClientSerializer: | |||
return serialized_event | |||
def serialize_events(self, events, time_now, **kwargs): | |||
async def serialize_events( | |||
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any | |||
) -> List[JsonDict]: | |||
"""Serializes multiple events. | |||
Args: | |||
event (iter[EventBase]) | |||
time_now (int): The current time in milliseconds | |||
event | |||
time_now: The current time in milliseconds | |||
**kwargs: Arguments to pass to `serialize_event` | |||
Returns: | |||
Deferred[list[dict]]: The list of serialized events | |||
The list of serialized events | |||
""" | |||
return yieldable_gather_results( | |||
return await yieldable_gather_results( | |||
self.serialize_event, events, time_now=time_now, **kwargs | |||
) | |||
def copy_power_levels_contents( | |||
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]] | |||
): | |||
) -> Dict[str, Union[int, Dict[str, int]]]: | |||
"""Copy the content of a power_levels event, unfreezing frozendicts along the way | |||
Raises: | |||
@@ -475,7 +495,7 @@ def copy_power_levels_contents( | |||
if not isinstance(old_power_levels, collections.abc.Mapping): | |||
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) | |||
power_levels = {} | |||
power_levels: Dict[str, Union[int, Dict[str, int]]] = {} | |||
for k, v in old_power_levels.items(): | |||
if isinstance(v, int): | |||
@@ -483,7 +503,8 @@ def copy_power_levels_contents( | |||
continue | |||
if isinstance(v, collections.abc.Mapping): | |||
power_levels[k] = h = {} | |||
h: Dict[str, int] = {} | |||
power_levels[k] = h | |||
for k1, v1 in v.items(): | |||
# we should only have one level of nesting | |||
if not isinstance(v1, int): | |||
@@ -498,7 +519,7 @@ def copy_power_levels_contents( | |||
return power_levels | |||
def validate_canonicaljson(value: Any): | |||
def validate_canonicaljson(value: Any) -> None: | |||
""" | |||
Ensure that the JSON object is valid according to the rules of canonical JSON. | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import collections.abc | |||
from typing import Union | |||
from typing import Iterable, Union | |||
import jsonschema | |||
@@ -28,11 +28,11 @@ from synapse.events.utils import ( | |||
validate_canonicaljson, | |||
) | |||
from synapse.federation.federation_server import server_matches_acl_event | |||
from synapse.types import EventID, RoomID, UserID | |||
from synapse.types import EventID, JsonDict, RoomID, UserID | |||
class EventValidator: | |||
def validate_new(self, event: EventBase, config: HomeServerConfig): | |||
def validate_new(self, event: EventBase, config: HomeServerConfig) -> None: | |||
"""Validates the event has roughly the right format | |||
Args: | |||
@@ -116,7 +116,7 @@ class EventValidator: | |||
errcode=Codes.BAD_JSON, | |||
) | |||
def _validate_retention(self, event: EventBase): | |||
def _validate_retention(self, event: EventBase) -> None: | |||
"""Checks that an event that defines the retention policy for a room respects the | |||
format enforced by the spec. | |||
@@ -156,7 +156,7 @@ class EventValidator: | |||
errcode=Codes.BAD_JSON, | |||
) | |||
def validate_builder(self, event: Union[EventBase, EventBuilder]): | |||
def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None: | |||
"""Validates that the builder/event has roughly the right format. Only | |||
checks values that we expect a proto event to have, rather than all the | |||
fields an event would have | |||
@@ -204,14 +204,14 @@ class EventValidator: | |||
self._ensure_state_event(event) | |||
def _ensure_strings(self, d, keys): | |||
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None: | |||
for s in keys: | |||
if s not in d: | |||
raise SynapseError(400, "'%s' not in content" % (s,)) | |||
if not isinstance(d[s], str): | |||
raise SynapseError(400, "'%s' not a string type" % (s,)) | |||
def _ensure_state_event(self, event): | |||
def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None: | |||
if not event.is_state(): | |||
raise SynapseError(400, "'%s' must be state events" % (event.type,)) | |||
@@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = { | |||
} | |||
def _create_power_level_validator(): | |||
# This could return something newer than Draft 7, but that's the current "latest" | |||
# validator. | |||
def _create_power_level_validator() -> jsonschema.Draft7Validator: | |||
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) | |||
# by default jsonschema does not consider a frozendict to be an object so | |||
@@ -465,17 +465,35 @@ class RoomCreationHandler: | |||
# the room has been created | |||
# Calculate the minimum power level needed to clone the room | |||
event_power_levels = power_levels.get("events", {}) | |||
if not isinstance(event_power_levels, dict): | |||
event_power_levels = {} | |||
state_default = power_levels.get("state_default", 50) | |||
try: | |||
state_default_int = int(state_default) # type: ignore[arg-type] | |||
except (TypeError, ValueError): | |||
state_default_int = 50 | |||
ban = power_levels.get("ban", 50) | |||
needed_power_level = max(state_default, ban, max(event_power_levels.values())) | |||
try: | |||
ban = int(ban) # type: ignore[arg-type] | |||
except (TypeError, ValueError): | |||
ban = 50 | |||
needed_power_level = max( | |||
state_default_int, ban, max(event_power_levels.values()) | |||
) | |||
# Get the user's current power level, this matches the logic in get_user_power_level, | |||
# but without the entire state map. | |||
user_power_levels = power_levels.setdefault("users", {}) | |||
if not isinstance(user_power_levels, dict): | |||
user_power_levels = {} | |||
users_default = power_levels.get("users_default", 0) | |||
current_power_level = user_power_levels.get(user_id, users_default) | |||
try: | |||
current_power_level_int = int(current_power_level) # type: ignore[arg-type] | |||
except (TypeError, ValueError): | |||
current_power_level_int = 0 | |||
# Raise the requester's power level in the new room if necessary | |||
if current_power_level < needed_power_level: | |||
if current_power_level_int < needed_power_level: | |||
user_power_levels[user_id] = needed_power_level | |||
await self._send_events_for_new_room( | |||
@@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet): | |||
# Similarly, we don't allow relations to be applied to relations, so we | |||
# return the original relations without any aggregations on top of them | |||
# here. | |||
events = await self._event_serializer.serialize_events( | |||
serialized_events = await self._event_serializer.serialize_events( | |||
events, now, bundle_aggregations=False | |||
) | |||
return_value = pagination_chunk.to_dict() | |||
return_value["chunk"] = events | |||
return_value["chunk"] = serialized_events | |||
return_value["original_event"] = original_event | |||
return 200, return_value | |||
@@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet): | |||
) | |||
now = self.clock.time_msec() | |||
events = await self._event_serializer.serialize_events(events, now) | |||
serialized_events = await self._event_serializer.serialize_events(events, now) | |||
return_value = result.to_dict() | |||
return_value["chunk"] = events | |||
return_value["chunk"] = serialized_events | |||
return 200, return_value | |||