@@ -0,0 +1 @@ | |||
Improve type annotations in Synapse's test suite. |
@@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_get_or_create_user_mau_not_blocked(self): | |||
self.store.count_monthly_users = Mock( | |||
# Type ignore: mypy doesn't like us assigning to methods. | |||
self.store.count_monthly_users = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) | |||
) | |||
# Ensure does not throw exception | |||
@@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_get_or_create_user_mau_blocked(self): | |||
self.store.get_monthly_active_count = Mock( | |||
# Type ignore: mypy doesn't like us assigning to methods. | |||
self.store.get_monthly_active_count = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(self.lots_of_users) | |||
) | |||
self.get_failure( | |||
@@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
ResourceLimitError, | |||
) | |||
self.store.get_monthly_active_count = Mock( | |||
# Type ignore: mypy doesn't like us assigning to methods. | |||
self.store.get_monthly_active_count = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value) | |||
) | |||
self.get_failure( | |||
@@ -28,11 +28,12 @@ from typing import ( | |||
MutableMapping, | |||
Optional, | |||
Tuple, | |||
Union, | |||
overload, | |||
) | |||
from unittest.mock import patch | |||
import attr | |||
from typing_extensions import Literal | |||
from twisted.web.resource import Resource | |||
from twisted.web.server import Site | |||
@@ -55,6 +56,32 @@ class RestHelper: | |||
site = attr.ib(type=Site) | |||
auth_user_id = attr.ib() | |||
@overload | |||
def create_room_as( | |||
self, | |||
room_creator: Optional[str] = ..., | |||
is_public: Optional[bool] = ..., | |||
room_version: Optional[str] = ..., | |||
tok: Optional[str] = ..., | |||
expect_code: Literal[200] = ..., | |||
extra_content: Optional[Dict] = ..., | |||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., | |||
) -> str: | |||
... | |||
@overload | |||
def create_room_as( | |||
self, | |||
room_creator: Optional[str] = ..., | |||
is_public: Optional[bool] = ..., | |||
room_version: Optional[str] = ..., | |||
tok: Optional[str] = ..., | |||
expect_code: int = ..., | |||
extra_content: Optional[Dict] = ..., | |||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ..., | |||
) -> Optional[str]: | |||
... | |||
def create_room_as( | |||
self, | |||
room_creator: Optional[str] = None, | |||
@@ -64,7 +91,7 @@ class RestHelper: | |||
expect_code: int = 200, | |||
extra_content: Optional[Dict] = None, | |||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, | |||
) -> str: | |||
) -> Optional[str]: | |||
""" | |||
Create a room. | |||
@@ -107,6 +134,8 @@ class RestHelper: | |||
if expect_code == 200: | |||
return channel.json_body["room_id"] | |||
else: | |||
return None | |||
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): | |||
self.change_membership( | |||
@@ -176,7 +205,7 @@ class RestHelper: | |||
extra_data: Optional[dict] = None, | |||
tok: Optional[str] = None, | |||
expect_code: int = 200, | |||
expect_errcode: str = None, | |||
expect_errcode: Optional[str] = None, | |||
) -> None: | |||
""" | |||
Send a membership state event into a room. | |||
@@ -260,9 +289,7 @@ class RestHelper: | |||
txn_id=None, | |||
tok=None, | |||
expect_code=200, | |||
custom_headers: Optional[ | |||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] | |||
] = None, | |||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, | |||
): | |||
if txn_id is None: | |||
txn_id = "m%s" % (str(time.time())) | |||
@@ -509,7 +536,7 @@ class RestHelper: | |||
went. | |||
""" | |||
cookies = {} | |||
cookies: Dict[str, str] = {} | |||
# if we're doing a ui auth, hit the ui auth redirect endpoint | |||
if ui_auth_session_id: | |||
@@ -631,7 +658,13 @@ class RestHelper: | |||
# hit the redirect url again with the right Host header, which should now issue | |||
# a cookie and redirect to the SSO provider. | |||
location = channel.headers.getRawHeaders("Location")[0] | |||
def get_location(channel: FakeChannel) -> str: | |||
location_values = channel.headers.getRawHeaders("Location") | |||
# Keep mypy happy by asserting that location_values is nonempty | |||
assert location_values | |||
return location_values[0] | |||
location = get_location(channel) | |||
parts = urllib.parse.urlsplit(location) | |||
channel = make_request( | |||
self.hs.get_reactor(), | |||
@@ -645,7 +678,7 @@ class RestHelper: | |||
assert channel.code == 302 | |||
channel.extract_cookies(cookies) | |||
return channel.headers.getRawHeaders("Location")[0] | |||
return get_location(channel) | |||
def initiate_sso_ui_auth( | |||
self, ui_auth_session_id: str, cookies: MutableMapping[str, str] | |||
@@ -24,6 +24,7 @@ from typing import ( | |||
MutableMapping, | |||
Optional, | |||
Tuple, | |||
Type, | |||
Union, | |||
) | |||
@@ -226,7 +227,7 @@ def make_request( | |||
path: Union[bytes, str], | |||
content: Union[bytes, str, JsonDict] = b"", | |||
access_token: Optional[str] = None, | |||
request: Request = SynapseRequest, | |||
request: Type[Request] = SynapseRequest, | |||
shorthand: bool = True, | |||
federation_auth_origin: Optional[bytes] = None, | |||
content_is_form: bool = False, | |||
@@ -44,6 +44,7 @@ from twisted.python.threadpool import ThreadPool | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from twisted.trial import unittest | |||
from twisted.web.resource import Resource | |||
from twisted.web.server import Request | |||
from synapse import events | |||
from synapse.api.constants import EventTypes, Membership | |||
@@ -95,16 +96,13 @@ def around(target): | |||
return _around | |||
T = TypeVar("T") | |||
class TestCase(unittest.TestCase): | |||
"""A subclass of twisted.trial's TestCase which looks for 'loglevel' | |||
attributes on both itself and its individual test methods, to override the | |||
root logger's logging level while that test (case|method) runs.""" | |||
def __init__(self, methodName, *args, **kwargs): | |||
super().__init__(methodName, *args, **kwargs) | |||
def __init__(self, methodName: str): | |||
super().__init__(methodName) | |||
method = getattr(self, methodName) | |||
@@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase): | |||
Attributes: | |||
servlets: List of servlet registration function. | |||
user_id (str): The user ID to assume if auth is hijacked. | |||
hijack_auth (bool): Whether to hijack auth to return the user specified | |||
hijack_auth: Whether to hijack auth to return the user specified | |||
in user_id. | |||
""" | |||
hijack_auth = True | |||
needs_threadpool = False | |||
hijack_auth: ClassVar[bool] = True | |||
needs_threadpool: ClassVar[bool] = False | |||
servlets: ClassVar[List[RegisterServletsFunc]] = [] | |||
def __init__(self, methodName, *args, **kwargs): | |||
super().__init__(methodName, *args, **kwargs) | |||
def __init__(self, methodName: str): | |||
super().__init__(methodName) | |||
# see if we have any additional config for this test | |||
method = getattr(self, methodName) | |||
@@ -301,9 +299,10 @@ class HomeserverTestCase(TestCase): | |||
None, | |||
) | |||
self.hs.get_auth().get_user_by_req = get_user_by_req | |||
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token | |||
self.hs.get_auth().get_access_token_from_request = Mock( | |||
# Type ignore: mypy doesn't like us assigning to methods. | |||
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment] | |||
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] | |||
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment] | |||
return_value="1234" | |||
) | |||
@@ -417,7 +416,7 @@ class HomeserverTestCase(TestCase): | |||
path: Union[bytes, str], | |||
content: Union[bytes, str, JsonDict] = b"", | |||
access_token: Optional[str] = None, | |||
request: Type[T] = SynapseRequest, | |||
request: Type[Request] = SynapseRequest, | |||
shorthand: bool = True, | |||
federation_auth_origin: Optional[bytes] = None, | |||
content_is_form: bool = False, | |||
@@ -596,7 +595,7 @@ class HomeserverTestCase(TestCase): | |||
nonce_str += b"\x00notadmin" | |||
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) | |||
want_mac = want_mac.hexdigest() | |||
want_mac_digest = want_mac.hexdigest() | |||
body = json.dumps( | |||
{ | |||
@@ -605,7 +604,7 @@ class HomeserverTestCase(TestCase): | |||
"displayname": displayname, | |||
"password": password, | |||
"admin": admin, | |||
"mac": want_mac, | |||
"mac": want_mac_digest, | |||
"inhibit_login": True, | |||
} | |||
) | |||