@@ -0,0 +1 @@ | |||
Add type hints to `Notifier`. |
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, List | |||
from typing import TYPE_CHECKING, List, Tuple | |||
from canonicaljson import json | |||
@@ -54,7 +54,10 @@ class TransactionManager(object): | |||
@measure_func("_send_new_transaction") | |||
async def send_new_transaction( | |||
self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu] | |||
self, | |||
destination: str, | |||
pending_pdus: List[Tuple[EventBase, int]], | |||
pending_edus: List[Edu], | |||
): | |||
# Make a transaction-sending opentracing span. This span follows on from | |||
@@ -25,6 +25,7 @@ from typing import ( | |||
Set, | |||
Tuple, | |||
TypeVar, | |||
Union, | |||
) | |||
from prometheus_client import Counter | |||
@@ -186,7 +187,7 @@ class Notifier(object): | |||
self.store = hs.get_datastore() | |||
self.pending_new_room_events = ( | |||
[] | |||
) # type: List[Tuple[int, EventBase, Collection[str]]] | |||
) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]] | |||
# Called when there are new things to stream over replication | |||
self.replication_callbacks = [] # type: List[Callable[[], None]] | |||
@@ -246,7 +247,7 @@ class Notifier(object): | |||
event: EventBase, | |||
room_stream_id: int, | |||
max_room_stream_id: int, | |||
extra_users: Collection[str] = [], | |||
extra_users: Collection[Union[str, UserID]] = [], | |||
): | |||
""" Used by handlers to inform the notifier something has happened | |||
in the room, room event wise. | |||
@@ -282,7 +283,10 @@ class Notifier(object): | |||
self._on_new_room_event(event, room_stream_id, extra_users) | |||
def _on_new_room_event( | |||
self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = [] | |||
self, | |||
event: EventBase, | |||
room_stream_id: int, | |||
extra_users: Collection[Union[str, UserID]] = [], | |||
): | |||
"""Notify any user streams that are interested in this room event""" | |||
# poke any interested application service. | |||
@@ -310,7 +314,7 @@ class Notifier(object): | |||
self, | |||
stream_key: str, | |||
new_token: int, | |||
users: Collection[str] = [], | |||
users: Collection[Union[str, UserID]] = [], | |||
rooms: Collection[str] = [], | |||
): | |||
""" Used to inform listeners that something has happened event wise. | |||
@@ -13,11 +13,12 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import abc | |||
import re | |||
import string | |||
import sys | |||
from collections import namedtuple | |||
from typing import Any, Dict, Tuple, TypeVar | |||
from typing import Any, Dict, Tuple, Type, TypeVar | |||
import attr | |||
from signedjson.key import decode_verify_key_bytes | |||
@@ -33,7 +34,7 @@ else: | |||
T_co = TypeVar("T_co", covariant=True) | |||
class Collection(Iterable[T_co], Container[T_co], Sized): | |||
class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore | |||
__slots__ = () | |||
@@ -141,6 +142,9 @@ def get_localpart_from_id(string): | |||
return string[1:idx] | |||
DS = TypeVar("DS", bound="DomainSpecificString") | |||
class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))): | |||
"""Common base class among ID/name strings that have a local part and a | |||
domain name, prefixed with a sigil. | |||
@@ -151,6 +155,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom | |||
'domain' : The domain part of the name | |||
""" | |||
__metaclass__ = abc.ABCMeta | |||
SIGIL = abc.abstractproperty() # type: str # type: ignore | |||
# Deny iteration because it will bite you if you try to create a singleton | |||
# set by: | |||
# users = set(user) | |||
@@ -166,7 +174,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom | |||
return self | |||
@classmethod | |||
def from_string(cls, s: str): | |||
def from_string(cls: Type[DS], s: str) -> DS: | |||
"""Parse the string given by 's' into a structure object.""" | |||
if len(s) < 1 or s[0:1] != cls.SIGIL: | |||
raise SynapseError( | |||
@@ -190,12 +198,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom | |||
# names on one HS | |||
return cls(localpart=parts[0], domain=domain) | |||
def to_string(self): | |||
def to_string(self) -> str: | |||
"""Return a string encoding the fields of the structure object.""" | |||
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) | |||
@classmethod | |||
def is_valid(cls, s): | |||
def is_valid(cls: Type[DS], s: str) -> bool: | |||
try: | |||
cls.from_string(s) | |||
return True | |||
@@ -235,8 +243,9 @@ class GroupID(DomainSpecificString): | |||
SIGIL = "+" | |||
@classmethod | |||
def from_string(cls, s): | |||
group_id = super(GroupID, cls).from_string(s) | |||
def from_string(cls: Type[DS], s: str) -> DS: | |||
group_id = super().from_string(s) # type: DS # type: ignore | |||
if not group_id.localpart: | |||
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) | |||
@@ -15,6 +15,7 @@ | |||
import logging | |||
from functools import wraps | |||
from typing import Any, Callable, Optional, TypeVar, cast | |||
from prometheus_client import Counter | |||
@@ -57,8 +58,10 @@ in_flight = InFlightGauge( | |||
sub_metrics=["real_time_max", "real_time_sum"], | |||
) | |||
T = TypeVar("T", bound=Callable[..., Any]) | |||
def measure_func(name=None): | |||
def measure_func(name: Optional[str] = None) -> Callable[[T], T]: | |||
""" | |||
Used to decorate an async function with a `Measure` context manager. | |||
@@ -76,7 +79,7 @@ def measure_func(name=None): | |||
""" | |||
def wrapper(func): | |||
def wrapper(func: T) -> T: | |||
block_name = func.__name__ if name is None else name | |||
@wraps(func) | |||
@@ -85,7 +88,7 @@ def measure_func(name=None): | |||
r = await func(self, *args, **kwargs) | |||
return r | |||
return measured_func | |||
return cast(T, measured_func) | |||
return wrapper | |||
@@ -212,7 +212,9 @@ commands = mypy \ | |||
synapse/storage/state.py \ | |||
synapse/storage/util \ | |||
synapse/streams \ | |||
synapse/types.py \ | |||
synapse/util/caches/stream_change_cache.py \ | |||
synapse/util/metrics.py \ | |||
tests/replication \ | |||
tests/test_utils \ | |||
tests/rest/client/v2_alpha/test_auth.py \ | |||