And do not allow untyped defs in tests.handlers.tags/v1.75.0rc1
@@ -0,0 +1 @@ | |||
Add missing type hints. |
@@ -95,10 +95,7 @@ disallow_untyped_defs = True | |||
[mypy-tests.federation.transport.test_client] | |||
disallow_untyped_defs = True | |||
[mypy-tests.handlers.test_sso] | |||
disallow_untyped_defs = True | |||
[mypy-tests.handlers.test_user_directory] | |||
[mypy-tests.handlers.*] | |||
disallow_untyped_defs = True | |||
[mypy-tests.metrics.test_background_process_metrics] | |||
@@ -2031,7 +2031,7 @@ class PasswordAuthProvider: | |||
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] | |||
# Mapping from login type to login parameters | |||
self._supported_login_types: Dict[str, Iterable[str]] = {} | |||
self._supported_login_types: Dict[str, Tuple[str, ...]] = {} | |||
# Mapping from login type to auth checker callbacks | |||
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} | |||
@@ -31,7 +31,7 @@ from synapse.appservice import ( | |||
from synapse.handlers.appservice import ApplicationServicesHandler | |||
from synapse.rest.client import login, receipts, register, room, sendtodevice | |||
from synapse.server import HomeServer | |||
from synapse.types import RoomStreamToken | |||
from synapse.types import JsonDict, RoomStreamToken | |||
from synapse.util import Clock | |||
from synapse.util.stringutils import random_string | |||
@@ -44,7 +44,7 @@ from tests.utils import MockClock | |||
class AppServiceHandlerTestCase(unittest.TestCase): | |||
"""Tests the ApplicationServicesHandler.""" | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
self.mock_store = Mock() | |||
self.mock_as_api = Mock() | |||
self.mock_scheduler = Mock() | |||
@@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
self.handler = ApplicationServicesHandler(hs) | |||
self.event_source = hs.get_event_sources() | |||
def test_notify_interested_services(self): | |||
def test_notify_interested_services(self) -> None: | |||
interested_service = self._mkservice(is_interested_in_event=True) | |||
services = [ | |||
self._mkservice(is_interested_in_event=False), | |||
@@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
interested_service, events=[event] | |||
) | |||
def test_query_user_exists_unknown_user(self): | |||
def test_query_user_exists_unknown_user(self) -> None: | |||
user_id = "@someone:anywhere" | |||
services = [self._mkservice(is_interested_in_event=True)] | |||
services[0].is_interested_in_user.return_value = True | |||
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) | |||
def test_query_user_exists_known_user(self): | |||
def test_query_user_exists_known_user(self) -> None: | |||
user_id = "@someone:anywhere" | |||
services = [self._mkservice(is_interested_in_event=True)] | |||
services[0].is_interested_in_user.return_value = True | |||
@@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
"query_user called when it shouldn't have been.", | |||
) | |||
def test_query_room_alias_exists(self): | |||
def test_query_room_alias_exists(self) -> None: | |||
room_alias_str = "#foo:bar" | |||
room_alias = Mock() | |||
room_alias.to_string.return_value = room_alias_str | |||
@@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
self.assertEqual(result.room_id, room_id) | |||
self.assertEqual(result.servers, servers) | |||
def test_get_3pe_protocols_no_appservices(self): | |||
def test_get_3pe_protocols_no_appservices(self) -> None: | |||
self.mock_store.get_app_services.return_value = [] | |||
response = self.successResultOf( | |||
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) | |||
@@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
self.mock_as_api.get_3pe_protocol.assert_not_called() | |||
self.assertEqual(response, {}) | |||
def test_get_3pe_protocols_no_protocols(self): | |||
def test_get_3pe_protocols_no_protocols(self) -> None: | |||
service = self._mkservice(False, []) | |||
self.mock_store.get_app_services.return_value = [service] | |||
response = self.successResultOf( | |||
@@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
self.mock_as_api.get_3pe_protocol.assert_not_called() | |||
self.assertEqual(response, {}) | |||
def test_get_3pe_protocols_protocol_no_response(self): | |||
def test_get_3pe_protocols_protocol_no_response(self) -> None: | |||
service = self._mkservice(False, ["my-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service] | |||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) | |||
@@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
) | |||
self.assertEqual(response, {}) | |||
def test_get_3pe_protocols_select_one_protocol(self): | |||
def test_get_3pe_protocols_select_one_protocol(self) -> None: | |||
service = self._mkservice(False, ["my-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service] | |||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( | |||
@@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} | |||
) | |||
def test_get_3pe_protocols_one_protocol(self): | |||
def test_get_3pe_protocols_one_protocol(self) -> None: | |||
service = self._mkservice(False, ["my-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service] | |||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( | |||
@@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} | |||
) | |||
def test_get_3pe_protocols_multiple_protocol(self): | |||
def test_get_3pe_protocols_multiple_protocol(self) -> None: | |||
service_one = self._mkservice(False, ["my-protocol"]) | |||
service_two = self._mkservice(False, ["other-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service_one, service_two] | |||
@@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
}, | |||
) | |||
def test_get_3pe_protocols_multiple_info(self): | |||
def test_get_3pe_protocols_multiple_info(self) -> None: | |||
service_one = self._mkservice(False, ["my-protocol"]) | |||
service_two = self._mkservice(False, ["my-protocol"]) | |||
async def get_3pe_protocol(service, unusedProtocol): | |||
async def get_3pe_protocol( | |||
service: ApplicationService, protocol: str | |||
) -> Optional[JsonDict]: | |||
if service == service_one: | |||
return { | |||
"x-protocol-data": 42, | |||
@@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
}, | |||
) | |||
def test_notify_interested_services_ephemeral(self): | |||
def test_notify_interested_services_ephemeral(self) -> None: | |||
""" | |||
Test sending ephemeral events to the appservice handler are scheduled | |||
to be pushed out to interested appservices, and that the stream ID is | |||
@@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
580, | |||
) | |||
def test_notify_interested_services_ephemeral_out_of_order(self): | |||
def test_notify_interested_services_ephemeral_out_of_order(self) -> None: | |||
""" | |||
Test sending out of order ephemeral events to the appservice handler | |||
are ignored. | |||
@@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
receipts.register_servlets, | |||
] | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.hs = hs | |||
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that | |||
# we can track any outgoing ephemeral events | |||
@@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
"exclusive_as_user", "password", self.exclusive_as_user_device_id | |||
) | |||
def _notify_interested_services(self): | |||
def _notify_interested_services(self) -> None: | |||
# This is normally set in `notify_interested_services` but we need to call the | |||
# internal async version so the reactor gets pushed to completion. | |||
self.hs.get_application_service_handler().current_max += 1 | |||
@@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
) | |||
def test_match_interesting_room_members( | |||
self, interesting_user: str, should_notify: bool | |||
): | |||
) -> None: | |||
""" | |||
Test to make sure that a interesting user (local or remote) in the room is | |||
notified as expected when someone else in the room sends a message. | |||
@@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
else: | |||
self.send_mock.assert_not_called() | |||
def test_application_services_receive_events_sent_by_interesting_local_user(self): | |||
def test_application_services_receive_events_sent_by_interesting_local_user( | |||
self, | |||
) -> None: | |||
""" | |||
Test to make sure that a messages sent from a local user can be interesting and | |||
picked up by the appservice. | |||
@@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(events[0]["type"], "m.room.message") | |||
self.assertEqual(events[0]["sender"], alice) | |||
def test_sending_read_receipt_batches_to_application_services(self): | |||
def test_sending_read_receipt_batches_to_application_services(self) -> None: | |||
"""Tests that a large batch of read receipts are sent correctly to | |||
interested application services. | |||
""" | |||
@@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
@unittest.override_config( | |||
{"experimental_features": {"msc2409_to_device_messages_enabled": True}} | |||
) | |||
def test_application_services_receive_local_to_device(self): | |||
def test_application_services_receive_local_to_device(self) -> None: | |||
""" | |||
Test that when a user sends a to-device message to another user | |||
that is an application service's user namespace, the | |||
@@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
@unittest.override_config( | |||
{"experimental_features": {"msc2409_to_device_messages_enabled": True}} | |||
) | |||
def test_application_services_receive_bursts_of_to_device(self): | |||
def test_application_services_receive_bursts_of_to_device(self) -> None: | |||
""" | |||
Test that when a user sends >100 to-device messages at once, any | |||
interested AS's will receive them in separate transactions. | |||
@@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) | |||
experimental_feature_enabled: bool, | |||
as_supports_txn_extensions: bool, | |||
as_should_receive_device_list_updates: bool, | |||
): | |||
) -> None: | |||
""" | |||
Tests that an application service receives notice of changed device | |||
lists for a user, when a user changes their device lists. | |||
@@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): | |||
and a room for the users to talk in. | |||
""" | |||
async def preparation(): | |||
async def preparation() -> None: | |||
await self._add_otks_for_device(self._sender_user, self._sender_device, 42) | |||
await self._add_fallback_key_for_device( | |||
self._sender_user, self._sender_device, used=True | |||
@@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||
) | |||
def _mock_request(): | |||
def _mock_request() -> Mock: | |||
"""Returns a mock which will stand in as a SynapseRequest""" | |||
mock = Mock( | |||
spec=[ | |||
@@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor | |||
import synapse.api.errors | |||
import synapse.rest.admin | |||
from synapse.api.constants import EventTypes | |||
from synapse.events import EventBase | |||
from synapse.rest.client import directory, login, room | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, RoomAlias, create_requester | |||
@@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): | |||
self.test_user_tok = self.login("user", "pass") | |||
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) | |||
def _create_alias(self, user) -> None: | |||
def _create_alias(self, user: str) -> None: | |||
# Create a new alias to this room. | |||
self.get_success( | |||
self.store.create_room_alias_association( | |||
@@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): | |||
) | |||
return room_alias | |||
def _set_canonical_alias(self, content) -> None: | |||
def _set_canonical_alias(self, content: JsonDict) -> None: | |||
"""Configure the canonical alias state on the room.""" | |||
self.helper.send_state( | |||
self.room_id, | |||
@@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): | |||
tok=self.admin_user_tok, | |||
) | |||
def _get_canonical_alias(self): | |||
def _get_canonical_alias(self) -> EventBase: | |||
"""Get the canonical alias state of the room.""" | |||
return self.get_success( | |||
result = self.get_success( | |||
self._storage_controllers.state.get_current_state_event( | |||
self.room_id, EventTypes.CanonicalAlias, "" | |||
) | |||
) | |||
assert result is not None | |||
return result | |||
def test_remove_alias(self) -> None: | |||
"""Removing an alias that is the canonical alias should remove it there too.""" | |||
@@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): | |||
) | |||
data = self._get_canonical_alias() | |||
self.assertEqual(data["content"]["alias"], self.test_alias) | |||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) | |||
self.assertEqual(data.content["alias"], self.test_alias) | |||
self.assertEqual(data.content["alt_aliases"], [self.test_alias]) | |||
# Finally, delete the alias. | |||
self.get_success( | |||
@@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): | |||
) | |||
data = self._get_canonical_alias() | |||
self.assertNotIn("alias", data["content"]) | |||
self.assertNotIn("alt_aliases", data["content"]) | |||
self.assertNotIn("alias", data.content) | |||
self.assertNotIn("alt_aliases", data.content) | |||
def test_remove_other_alias(self) -> None: | |||
"""Removing an alias listed as in alt_aliases should remove it there too.""" | |||
@@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): | |||
) | |||
data = self._get_canonical_alias() | |||
self.assertEqual(data["content"]["alias"], self.test_alias) | |||
self.assertEqual(data.content["alias"], self.test_alias) | |||
self.assertEqual( | |||
data["content"]["alt_aliases"], [self.test_alias, other_test_alias] | |||
data.content["alt_aliases"], [self.test_alias, other_test_alias] | |||
) | |||
# Delete the second alias. | |||
@@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): | |||
) | |||
data = self._get_canonical_alias() | |||
self.assertEqual(data["content"]["alias"], self.test_alias) | |||
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias]) | |||
self.assertEqual(data.content["alias"], self.test_alias) | |||
self.assertEqual(data.content["alt_aliases"], [self.test_alias]) | |||
class TestCreateAliasACL(unittest.HomeserverTestCase): | |||
@@ -17,7 +17,11 @@ | |||
import copy | |||
from unittest import mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.errors import SynapseError | |||
from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests import unittest | |||
@@ -39,14 +43,14 @@ room_keys = { | |||
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
return self.setup_test_homeserver(replication_layer=mock.Mock()) | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.handler = hs.get_e2e_room_keys_handler() | |||
self.local_user = "@boris:" + hs.hostname | |||
def test_get_missing_current_version_info(self): | |||
def test_get_missing_current_version_info(self) -> None: | |||
"""Check that we get a 404 if we ask for info about the current version | |||
if there is no version. | |||
""" | |||
@@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_get_missing_version_info(self): | |||
def test_get_missing_version_info(self) -> None: | |||
"""Check that we get a 404 if we ask for info about a specific version | |||
if it doesn't exist. | |||
""" | |||
@@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_create_version(self): | |||
def test_create_version(self) -> None: | |||
"""Check that we can create and then retrieve versions.""" | |||
res = self.get_success( | |||
version = self.get_success( | |||
self.handler.create_version( | |||
self.local_user, | |||
{ | |||
@@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
) | |||
self.assertEqual(res, "1") | |||
self.assertEqual(version, "1") | |||
# check we can retrieve it as the current version | |||
res = self.get_success(self.handler.get_version_info(self.local_user)) | |||
@@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
) | |||
# upload a new one... | |||
res = self.get_success( | |||
version = self.get_success( | |||
self.handler.create_version( | |||
self.local_user, | |||
{ | |||
@@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
) | |||
self.assertEqual(res, "2") | |||
self.assertEqual(version, "2") | |||
# check we can retrieve it as the current version | |||
res = self.get_success(self.handler.get_version_info(self.local_user)) | |||
@@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
def test_update_version(self): | |||
def test_update_version(self) -> None: | |||
"""Check that we can update versions.""" | |||
version = self.get_success( | |||
self.handler.create_version( | |||
@@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
def test_update_missing_version(self): | |||
def test_update_missing_version(self) -> None: | |||
"""Check that we get a 404 on updating nonexistent versions""" | |||
e = self.get_failure( | |||
self.handler.update_version( | |||
@@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_update_omitted_version(self): | |||
def test_update_omitted_version(self) -> None: | |||
"""Check that the update succeeds if the version is missing from the body""" | |||
version = self.get_success( | |||
self.handler.create_version( | |||
@@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
def test_update_bad_version(self): | |||
def test_update_bad_version(self) -> None: | |||
"""Check that we get a 400 if the version in the body doesn't match""" | |||
version = self.get_success( | |||
self.handler.create_version( | |||
@@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 400) | |||
def test_delete_missing_version(self): | |||
def test_delete_missing_version(self) -> None: | |||
"""Check that we get a 404 on deleting nonexistent versions""" | |||
e = self.get_failure( | |||
self.handler.delete_version(self.local_user, "1"), SynapseError | |||
@@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_delete_missing_current_version(self): | |||
def test_delete_missing_current_version(self) -> None: | |||
"""Check that we get a 404 on deleting nonexistent current version""" | |||
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError) | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_delete_version(self): | |||
def test_delete_version(self) -> None: | |||
"""Check that we can create and then delete versions.""" | |||
res = self.get_success( | |||
version = self.get_success( | |||
self.handler.create_version( | |||
self.local_user, | |||
{ | |||
@@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
) | |||
self.assertEqual(res, "1") | |||
self.assertEqual(version, "1") | |||
# check we can delete it | |||
self.get_success(self.handler.delete_version(self.local_user, "1")) | |||
@@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_get_missing_backup(self): | |||
def test_get_missing_backup(self) -> None: | |||
"""Check that we get a 404 on querying missing backup""" | |||
e = self.get_failure( | |||
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError | |||
@@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_get_missing_room_keys(self): | |||
def test_get_missing_room_keys(self) -> None: | |||
"""Check we get an empty response from an empty backup""" | |||
version = self.get_success( | |||
self.handler.create_version( | |||
@@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
# TODO: test the locking semantics when uploading room_keys, | |||
# although this is probably best done in sytest | |||
def test_upload_room_keys_no_versions(self): | |||
def test_upload_room_keys_no_versions(self) -> None: | |||
"""Check that we get a 404 on uploading keys when no versions are defined""" | |||
e = self.get_failure( | |||
self.handler.upload_room_keys(self.local_user, "no_version", room_keys), | |||
@@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_upload_room_keys_bogus_version(self): | |||
def test_upload_room_keys_bogus_version(self) -> None: | |||
"""Check that we get a 404 on uploading keys when an nonexistent version | |||
is specified | |||
""" | |||
@@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 404) | |||
def test_upload_room_keys_wrong_version(self): | |||
def test_upload_room_keys_wrong_version(self) -> None: | |||
"""Check that we get a 403 on uploading keys for an old version""" | |||
version = self.get_success( | |||
self.handler.create_version( | |||
@@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
res = e.value.code | |||
self.assertEqual(res, 403) | |||
def test_upload_room_keys_insert(self): | |||
def test_upload_room_keys_insert(self) -> None: | |||
"""Check that we can insert and retrieve keys for a session""" | |||
version = self.get_success( | |||
self.handler.create_version( | |||
@@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertDictEqual(res, room_keys) | |||
def test_upload_room_keys_merge(self): | |||
def test_upload_room_keys_merge(self) -> None: | |||
"""Check that we can upload a new room_key for an existing session and | |||
have it correctly merged""" | |||
version = self.get_success( | |||
@@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
self.handler.upload_room_keys(self.local_user, version, new_room_keys) | |||
) | |||
res = self.get_success(self.handler.get_room_keys(self.local_user, version)) | |||
res_keys = self.get_success( | |||
self.handler.get_room_keys(self.local_user, version) | |||
) | |||
self.assertEqual( | |||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], | |||
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], | |||
"SSBBTSBBIEZJU0gK", | |||
) | |||
@@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
self.handler.upload_room_keys(self.local_user, version, new_room_keys) | |||
) | |||
res = self.get_success(self.handler.get_room_keys(self.local_user, version)) | |||
res_keys = self.get_success( | |||
self.handler.get_room_keys(self.local_user, version) | |||
) | |||
self.assertEqual( | |||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" | |||
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], | |||
"new", | |||
) | |||
# the etag should NOT be equal now, since the key changed | |||
@@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
self.handler.upload_room_keys(self.local_user, version, new_room_keys) | |||
) | |||
res = self.get_success(self.handler.get_room_keys(self.local_user, version)) | |||
res_keys = self.get_success( | |||
self.handler.get_room_keys(self.local_user, version) | |||
) | |||
self.assertEqual( | |||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" | |||
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], | |||
"new", | |||
) | |||
# the etag should be the same since the session did not change | |||
@@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
# TODO: check edge cases as well as the common variations here | |||
def test_delete_room_keys(self): | |||
def test_delete_room_keys(self) -> None: | |||
"""Check that we can insert and delete keys for a session""" | |||
version = self.get_success( | |||
self.handler.create_version( | |||
@@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): | |||
user_id = self.register_user("kermit", "test") | |||
tok = self.login("kermit", "test") | |||
def create_invite(): | |||
def create_invite() -> EventBase: | |||
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) | |||
room_version = self.get_success(self.store.get_room_version(room_id)) | |||
return event_from_pdu_json( | |||
@@ -14,6 +14,8 @@ | |||
from typing import Optional | |||
from unittest import mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.errors import AuthError, StoreError | |||
from synapse.api.room_versions import RoomVersion | |||
from synapse.event_auth import ( | |||
@@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse | |||
from synapse.logging.context import LoggingContext | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import event_injection, make_awaitable | |||
@@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
room.register_servlets, | |||
] | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
# mock out the federation transport client | |||
self.mock_federation_transport_client = mock.Mock( | |||
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] | |||
@@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
) | |||
else: | |||
async def get_event(destination: str, event_id: str, timeout=None): | |||
async def get_event( | |||
destination: str, event_id: str, timeout: Optional[int] = None | |||
) -> JsonDict: | |||
self.assertEqual(destination, self.OTHER_SERVER_NAME) | |||
self.assertEqual(event_id, prev_event.event_id) | |||
return {"pdus": [prev_event.get_pdu_json()]} | |||
@@ -14,12 +14,16 @@ | |||
import logging | |||
from typing import Tuple | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EventTypes | |||
from synapse.events import EventBase | |||
from synapse.events.snapshot import EventContext | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.types import create_requester | |||
from synapse.util import Clock | |||
from synapse.util.stringutils import random_string | |||
from tests import unittest | |||
@@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
room.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.handler = self.hs.get_event_creation_handler() | |||
self._persist_event_storage_controller = ( | |||
self.hs.get_storage_controllers().persistence | |||
@@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
def test_duplicated_txn_id(self): | |||
def test_duplicated_txn_id(self) -> None: | |||
"""Test that attempting to handle/persist an event with a transaction ID | |||
that has already been persisted correctly returns the old event and does | |||
*not* produce duplicate messages. | |||
@@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
# rather than the new one. | |||
self.assertEqual(ret_event1.event_id, ret_event4.event_id) | |||
def test_duplicated_txn_id_one_call(self): | |||
def test_duplicated_txn_id_one_call(self) -> None: | |||
"""Test that we correctly handle duplicates that we try and persist at | |||
the same time. | |||
""" | |||
@@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(len(events), 2) | |||
self.assertEqual(events[0].event_id, events[1].event_id) | |||
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self): | |||
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events( | |||
self, | |||
) -> None: | |||
"""When we set allow_no_prev_events=True, should be able to create a | |||
event without any prev_events (only auth_events). | |||
""" | |||
@@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events( | |||
self, | |||
): | |||
) -> None: | |||
"""When we set allow_no_prev_events=False, shouldn't be able to create a | |||
event without any prev_events even if it has auth_events. Expect an | |||
exception to be raised. | |||
@@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events( | |||
self, | |||
): | |||
) -> None: | |||
"""When we set allow_no_prev_events=True, should be able to create a | |||
event without any prev_events or auth_events. Expect an exception to be | |||
raised. | |||
@@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): | |||
room.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.user_id = self.register_user("tester", "foobar") | |||
self.access_token = self.login("tester", "foobar") | |||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) | |||
def test_allow_server_acl(self): | |||
def test_allow_server_acl(self) -> None: | |||
"""Test that sending an ACL that blocks everyone but ourselves works.""" | |||
self.helper.send_state( | |||
@@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): | |||
expect_code=200, | |||
) | |||
def test_deny_server_acl_block_outselves(self): | |||
def test_deny_server_acl_block_outselves(self) -> None: | |||
"""Test that sending an ACL that blocks ourselves does not work.""" | |||
self.helper.send_state( | |||
self.room_id, | |||
@@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase): | |||
expect_code=400, | |||
) | |||
def test_deny_redact_server_acl(self): | |||
def test_deny_redact_server_acl(self) -> None: | |||
"""Test that attempting to redact an ACL is blocked.""" | |||
body = self.helper.send_state( | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import os | |||
from typing import Any, Dict, Tuple | |||
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple | |||
from unittest.mock import ANY, Mock, patch | |||
from urllib.parse import parse_qs, urlparse | |||
@@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.handlers.sso import MappingException | |||
from synapse.http.site import SynapseRequest | |||
from synapse.server import HomeServer | |||
from synapse.types import UserID | |||
from synapse.types import JsonDict, UserID | |||
from synapse.util import Clock | |||
from synapse.util.macaroons import get_value_from_macaroon | |||
from synapse.util.stringutils import random_string | |||
@@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config | |||
try: | |||
import authlib # noqa: F401 | |||
from authlib.oidc.core import UserInfo | |||
from authlib.oidc.discovery import OpenIDProviderMetadata | |||
from synapse.handlers.oidc import Token, UserAttributeDict | |||
HAS_OIDC = True | |||
except ImportError: | |||
@@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = { | |||
class TestMappingProvider: | |||
@staticmethod | |||
def parse_config(config): | |||
return | |||
def parse_config(config: JsonDict) -> None: | |||
return None | |||
def __init__(self, config): | |||
def __init__(self, config: None): | |||
pass | |||
def get_remote_user_id(self, userinfo): | |||
def get_remote_user_id(self, userinfo: "UserInfo") -> str: | |||
return userinfo["sub"] | |||
async def map_user_attributes(self, userinfo, token): | |||
return {"localpart": userinfo["username"], "display_name": None} | |||
async def map_user_attributes( | |||
self, userinfo: "UserInfo", token: "Token" | |||
) -> "UserAttributeDict": | |||
# This is testing not providing the full map. | |||
return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item] | |||
# Do not include get_extra_attributes to test backwards compatibility paths. | |||
class TestMappingProviderExtra(TestMappingProvider): | |||
async def get_extra_attributes(self, userinfo, token): | |||
async def get_extra_attributes( | |||
self, userinfo: "UserInfo", token: "Token" | |||
) -> JsonDict: | |||
return {"phone": userinfo["phone"]} | |||
class TestMappingProviderFailures(TestMappingProvider): | |||
async def map_user_attributes(self, userinfo, token, failures): | |||
return { | |||
# Superclass is testing the legacy interface for map_user_attributes. | |||
async def map_user_attributes( # type: ignore[override] | |||
self, userinfo: "UserInfo", token: "Token", failures: int | |||
) -> "UserAttributeDict": | |||
return { # type: ignore[typeddict-item] | |||
"localpart": userinfo["username"] + (str(failures) if failures else ""), | |||
"display_name": None, | |||
} | |||
@@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase): | |||
self.hs_patcher.stop() | |||
return super().tearDown() | |||
def reset_mocks(self): | |||
def reset_mocks(self) -> None: | |||
"""Reset all the Mocks.""" | |||
self.fake_server.reset_mocks() | |||
self.render_error.reset_mock() | |||
self.complete_sso_login.reset_mock() | |||
def metadata_edit(self, values): | |||
def metadata_edit(self, values: dict) -> ContextManager[Mock]: | |||
"""Modify the result that will be returned by the well-known query""" | |||
metadata = self.fake_server.get_metadata() | |||
@@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase): | |||
session = self._generate_oidc_session_token(state, nonce, client_redirect_url) | |||
return _build_callback_request(code, state, session), grant | |||
def assertRenderedError(self, error, error_description=None): | |||
def assertRenderedError( | |||
self, error: str, error_description: Optional[str] = None | |||
) -> Tuple[Any, ...]: | |||
self.render_error.assert_called_once() | |||
args = self.render_error.call_args[0] | |||
self.assertEqual(args[1], error) | |||
@@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase): | |||
"""Provider metadatas are extensively validated.""" | |||
h = self.provider | |||
def force_load_metadata(): | |||
async def force_load(): | |||
def force_load_metadata() -> Awaitable[None]: | |||
async def force_load() -> "OpenIDProviderMetadata": | |||
return await h.load_metadata(force=True) | |||
return get_awaitable_result(force_load()) | |||
@@ -1198,7 +1212,7 @@ def _build_callback_request( | |||
state: str, | |||
session: str, | |||
ip_address: str = "10.0.0.1", | |||
): | |||
) -> Mock: | |||
"""Builds a fake SynapseRequest to mock the browser callback | |||
Returns a Mock object which looks like the SynapseRequest we get from a browser | |||
@@ -15,12 +15,13 @@ | |||
"""Tests for the password_auth_provider interface""" | |||
from http import HTTPStatus | |||
from typing import Any, Type, Union | |||
from typing import Any, Dict, List, Optional, Type, Union | |||
from unittest.mock import Mock | |||
import synapse | |||
from synapse.api.constants import LoginType | |||
from synapse.api.errors import Codes | |||
from synapse.handlers.account import AccountHandler | |||
from synapse.module_api import ModuleApi | |||
from synapse.rest.client import account, devices, login, logout, register | |||
from synapse.types import JsonDict, UserID | |||
@@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider: | |||
"""A legacy password_provider which only implements `check_password`.""" | |||
@staticmethod | |||
def parse_config(self): | |||
def parse_config(config: JsonDict) -> None: | |||
pass | |||
def __init__(self, config, account_handler): | |||
def __init__(self, config: None, account_handler: AccountHandler): | |||
pass | |||
def check_password(self, *args): | |||
def check_password(self, *args: str) -> Mock: | |||
return mock_password_provider.check_password(*args) | |||
@@ -58,16 +59,16 @@ class LegacyCustomAuthProvider: | |||
"""A legacy password_provider which implements a custom login type.""" | |||
@staticmethod | |||
def parse_config(self): | |||
def parse_config(config: JsonDict) -> None: | |||
pass | |||
def __init__(self, config, account_handler): | |||
def __init__(self, config: None, account_handler: AccountHandler): | |||
pass | |||
def get_supported_login_types(self): | |||
def get_supported_login_types(self) -> Dict[str, List[str]]: | |||
return {"test.login_type": ["test_field"]} | |||
def check_auth(self, *args): | |||
def check_auth(self, *args: str) -> Mock: | |||
return mock_password_provider.check_auth(*args) | |||
@@ -75,15 +76,15 @@ class CustomAuthProvider: | |||
"""A module which registers password_auth_provider callbacks for a custom login type.""" | |||
@staticmethod | |||
def parse_config(self): | |||
def parse_config(config: JsonDict) -> None: | |||
pass | |||
def __init__(self, config, api: ModuleApi): | |||
def __init__(self, config: None, api: ModuleApi): | |||
api.register_password_auth_provider_callbacks( | |||
auth_checkers={("test.login_type", ("test_field",)): self.check_auth} | |||
) | |||
def check_auth(self, *args): | |||
def check_auth(self, *args: Any) -> Mock: | |||
return mock_password_provider.check_auth(*args) | |||
@@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider: | |||
as a custom type.""" | |||
@staticmethod | |||
def parse_config(self): | |||
def parse_config(config: JsonDict) -> None: | |||
pass | |||
def __init__(self, config, account_handler): | |||
def __init__(self, config: None, account_handler: AccountHandler): | |||
pass | |||
def get_supported_login_types(self): | |||
def get_supported_login_types(self) -> Dict[str, List[str]]: | |||
return {"m.login.password": ["password"], "test.login_type": ["test_field"]} | |||
def check_auth(self, *args): | |||
def check_auth(self, *args: str) -> Mock: | |||
return mock_password_provider.check_auth(*args) | |||
@@ -110,10 +111,10 @@ class PasswordCustomAuthProvider: | |||
as well as a password login""" | |||
@staticmethod | |||
def parse_config(self): | |||
def parse_config(config: JsonDict) -> None: | |||
pass | |||
def __init__(self, config, api: ModuleApi): | |||
def __init__(self, config: None, api: ModuleApi): | |||
api.register_password_auth_provider_callbacks( | |||
auth_checkers={ | |||
("test.login_type", ("test_field",)): self.check_auth, | |||
@@ -121,10 +122,10 @@ class PasswordCustomAuthProvider: | |||
} | |||
) | |||
def check_auth(self, *args): | |||
def check_auth(self, *args: Any) -> Mock: | |||
return mock_password_provider.check_auth(*args) | |||
def check_pass(self, *args): | |||
def check_pass(self, *args: str) -> Mock: | |||
return mock_password_provider.check_password(*args) | |||
@@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
CALLBACK_USERNAME = "get_username_for_registration" | |||
CALLBACK_DISPLAYNAME = "get_displayname_for_registration" | |||
def setUp(self): | |||
def setUp(self) -> None: | |||
# we use a global mock device, so make sure we are starting with a clean slate | |||
mock_password_provider.reset_mock() | |||
super().setUp() | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_password_only_auth_progiver_login_legacy(self): | |||
def test_password_only_auth_progiver_login_legacy(self) -> None: | |||
self.password_only_auth_provider_login_test_body() | |||
def password_only_auth_provider_login_test_body(self): | |||
def password_only_auth_provider_login_test_body(self) -> None: | |||
# login flows should only have m.login.password | |||
flows = self._get_login_flows() | |||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) | |||
@@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
) | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_password_only_auth_provider_ui_auth_legacy(self): | |||
def test_password_only_auth_provider_ui_auth_legacy(self) -> None: | |||
self.password_only_auth_provider_ui_auth_test_body() | |||
def password_only_auth_provider_ui_auth_test_body(self): | |||
def password_only_auth_provider_ui_auth_test_body(self) -> None: | |||
"""UI Auth should delegate correctly to the password provider""" | |||
# create the user, otherwise access doesn't work | |||
@@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_local_user_fallback_login_legacy(self): | |||
def test_local_user_fallback_login_legacy(self) -> None: | |||
self.local_user_fallback_login_test_body() | |||
def local_user_fallback_login_test_body(self): | |||
def local_user_fallback_login_test_body(self) -> None: | |||
"""rejected login should fall back to local db""" | |||
self.register_user("localuser", "localpass") | |||
@@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual("@localuser:test", channel.json_body["user_id"]) | |||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) | |||
def test_local_user_fallback_ui_auth_legacy(self): | |||
def test_local_user_fallback_ui_auth_legacy(self) -> None: | |||
self.local_user_fallback_ui_auth_test_body() | |||
def local_user_fallback_ui_auth_test_body(self): | |||
def local_user_fallback_ui_auth_test_body(self) -> None: | |||
"""rejected login should fall back to local db""" | |||
self.register_user("localuser", "localpass") | |||
@@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"localdb_enabled": False}, | |||
} | |||
) | |||
def test_no_local_user_fallback_login_legacy(self): | |||
def test_no_local_user_fallback_login_legacy(self) -> None: | |||
self.no_local_user_fallback_login_test_body() | |||
def no_local_user_fallback_login_test_body(self): | |||
def no_local_user_fallback_login_test_body(self) -> None: | |||
"""localdb_enabled can block login with the local password""" | |||
self.register_user("localuser", "localpass") | |||
@@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"localdb_enabled": False}, | |||
} | |||
) | |||
def test_no_local_user_fallback_ui_auth_legacy(self): | |||
def test_no_local_user_fallback_ui_auth_legacy(self) -> None: | |||
self.no_local_user_fallback_ui_auth_test_body() | |||
def no_local_user_fallback_ui_auth_test_body(self): | |||
def no_local_user_fallback_ui_auth_test_body(self) -> None: | |||
"""localdb_enabled can block ui auth with the local password""" | |||
self.register_user("localuser", "localpass") | |||
@@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_auth_disabled_legacy(self): | |||
def test_password_auth_disabled_legacy(self) -> None: | |||
self.password_auth_disabled_test_body() | |||
def password_auth_disabled_test_body(self): | |||
def password_auth_disabled_test_body(self) -> None: | |||
"""password auth doesn't work if it's disabled across the board""" | |||
# login flows should be empty | |||
flows = self._get_login_flows() | |||
@@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.check_password.assert_not_called() | |||
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) | |||
def test_custom_auth_provider_login_legacy(self): | |||
def test_custom_auth_provider_login_legacy(self) -> None: | |||
self.custom_auth_provider_login_test_body() | |||
@override_config(providers_config(CustomAuthProvider)) | |||
def test_custom_auth_provider_login(self): | |||
def test_custom_auth_provider_login(self) -> None: | |||
self.custom_auth_provider_login_test_body() | |||
def custom_auth_provider_login_test_body(self): | |||
def custom_auth_provider_login_test_body(self) -> None: | |||
# login flows should have the custom flow and m.login.password, since we | |||
# haven't disabled local password lookup. | |||
# (password must come first, because reasons) | |||
@@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
) | |||
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) | |||
def test_custom_auth_provider_ui_auth_legacy(self): | |||
def test_custom_auth_provider_ui_auth_legacy(self) -> None: | |||
self.custom_auth_provider_ui_auth_test_body() | |||
@override_config(providers_config(CustomAuthProvider)) | |||
def test_custom_auth_provider_ui_auth(self): | |||
def test_custom_auth_provider_ui_auth(self) -> None: | |||
self.custom_auth_provider_ui_auth_test_body() | |||
def custom_auth_provider_ui_auth_test_body(self): | |||
def custom_auth_provider_ui_auth_test_body(self) -> None: | |||
# register the user and log in twice, to get two devices | |||
self.register_user("localuser", "localpass") | |||
tok1 = self.login("localuser", "localpass") | |||
@@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
) | |||
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) | |||
def test_custom_auth_provider_callback_legacy(self): | |||
def test_custom_auth_provider_callback_legacy(self) -> None: | |||
self.custom_auth_provider_callback_test_body() | |||
@override_config(providers_config(CustomAuthProvider)) | |||
def test_custom_auth_provider_callback(self): | |||
def test_custom_auth_provider_callback(self) -> None: | |||
self.custom_auth_provider_callback_test_body() | |||
def custom_auth_provider_callback_test_body(self): | |||
def custom_auth_provider_callback_test_body(self) -> None: | |||
callback = Mock(return_value=make_awaitable(None)) | |||
mock_password_provider.check_auth.return_value = make_awaitable( | |||
@@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_password_disabled_legacy(self): | |||
def test_custom_auth_password_disabled_legacy(self) -> None: | |||
self.custom_auth_password_disabled_test_body() | |||
@override_config( | |||
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} | |||
) | |||
def test_custom_auth_password_disabled(self): | |||
def test_custom_auth_password_disabled(self) -> None: | |||
self.custom_auth_password_disabled_test_body() | |||
def custom_auth_password_disabled_test_body(self): | |||
def custom_auth_password_disabled_test_body(self) -> None: | |||
"""Test login with a custom auth provider where password login is disabled""" | |||
self.register_user("localuser", "localpass") | |||
@@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False, "localdb_enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_password_disabled_localdb_enabled_legacy(self): | |||
def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None: | |||
self.custom_auth_password_disabled_localdb_enabled_test_body() | |||
@override_config( | |||
@@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False, "localdb_enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_password_disabled_localdb_enabled(self): | |||
def test_custom_auth_password_disabled_localdb_enabled(self) -> None: | |||
self.custom_auth_password_disabled_localdb_enabled_test_body() | |||
def custom_auth_password_disabled_localdb_enabled_test_body(self): | |||
def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None: | |||
"""Check the localdb_enabled == enabled == False | |||
Regression test for https://github.com/matrix-org/synapse/issues/8914: check | |||
@@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_login_legacy(self): | |||
def test_password_custom_auth_password_disabled_login_legacy(self) -> None: | |||
self.password_custom_auth_password_disabled_login_test_body() | |||
@override_config( | |||
@@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_login(self): | |||
def test_password_custom_auth_password_disabled_login(self) -> None: | |||
self.password_custom_auth_password_disabled_login_test_body() | |||
def password_custom_auth_password_disabled_login_test_body(self): | |||
def password_custom_auth_password_disabled_login_test_body(self) -> None: | |||
"""log in with a custom auth provider which implements password, but password | |||
login is disabled""" | |||
self.register_user("localuser", "localpass") | |||
@@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_ui_auth_legacy(self): | |||
def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None: | |||
self.password_custom_auth_password_disabled_ui_auth_test_body() | |||
@override_config( | |||
@@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"enabled": False}, | |||
} | |||
) | |||
def test_password_custom_auth_password_disabled_ui_auth(self): | |||
def test_password_custom_auth_password_disabled_ui_auth(self) -> None: | |||
self.password_custom_auth_password_disabled_ui_auth_test_body() | |||
def password_custom_auth_password_disabled_ui_auth_test_body(self): | |||
def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None: | |||
"""UI Auth with a custom auth provider which implements password, but password | |||
login is disabled""" | |||
# register the user and log in twice via the test login type to get two devices, | |||
@@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"localdb_enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_no_local_user_fallback_legacy(self): | |||
def test_custom_auth_no_local_user_fallback_legacy(self) -> None: | |||
self.custom_auth_no_local_user_fallback_test_body() | |||
@override_config( | |||
@@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"password_config": {"localdb_enabled": False}, | |||
} | |||
) | |||
def test_custom_auth_no_local_user_fallback(self): | |||
def test_custom_auth_no_local_user_fallback(self) -> None: | |||
self.custom_auth_no_local_user_fallback_test_body() | |||
def custom_auth_no_local_user_fallback_test_body(self): | |||
def custom_auth_no_local_user_fallback_test_body(self) -> None: | |||
"""Test login with a custom auth provider where the local db is disabled""" | |||
self.register_user("localuser", "localpass") | |||
@@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
channel = self._send_password_login("localuser", "localpass") | |||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) | |||
def test_on_logged_out(self): | |||
def test_on_logged_out(self) -> None: | |||
"""Tests that the on_logged_out callback is called when the user logs out.""" | |||
self.register_user("rin", "password") | |||
tok = self.login("rin", "password") | |||
self.called = False | |||
async def on_logged_out(user_id, device_id, access_token): | |||
async def on_logged_out( | |||
user_id: str, device_id: Optional[str], access_token: str | |||
) -> None: | |||
self.called = True | |||
on_logged_out = Mock(side_effect=on_logged_out) | |||
@@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
on_logged_out.assert_called_once() | |||
self.assertTrue(self.called) | |||
def test_username(self): | |||
def test_username(self) -> None: | |||
"""Tests that the get_username_for_registration callback can define the username | |||
of a user when registering. | |||
""" | |||
@@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mxid = channel.json_body["user_id"] | |||
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") | |||
def test_username_uia(self): | |||
def test_username_uia(self) -> None: | |||
"""Tests that the get_username_for_registration callback is only called at the | |||
end of the UIA flow. | |||
""" | |||
@@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
# Set some email configuration so the test doesn't fail because of its absence. | |||
@override_config({"email": {"notif_from": "noreply@test"}}) | |||
def test_3pid_allowed(self): | |||
def test_3pid_allowed(self) -> None: | |||
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse | |||
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind | |||
the 3PID. Also checks that the module is passed a boolean indicating whether the | |||
@@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self._test_3pid_allowed("rin", False) | |||
self._test_3pid_allowed("kitay", True) | |||
def test_displayname(self): | |||
def test_displayname(self) -> None: | |||
"""Tests that the get_displayname_for_registration callback can define the | |||
display name of a user when registering. | |||
""" | |||
@@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(display_name, username + "-foo") | |||
def test_displayname_uia(self): | |||
def test_displayname_uia(self) -> None: | |||
"""Tests that the get_displayname_for_registration callback is only called at the | |||
end of the UIA flow. | |||
""" | |||
@@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
# Check that the callback has been called. | |||
m.assert_called_once() | |||
def _test_3pid_allowed(self, username: str, registration: bool): | |||
def _test_3pid_allowed(self, username: str, registration: bool) -> None: | |||
"""Tests that the "is_3pid_allowed" module callback is called correctly, using | |||
either /register or /account URLs depending on the arguments. | |||
@@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
client is trying to register. | |||
""" | |||
async def callback(uia_results, params): | |||
async def callback(uia_results: JsonDict, params: JsonDict) -> str: | |||
self.assertIn(LoginType.DUMMY, uia_results) | |||
username = params["username"] | |||
return username + "-foo" | |||
@@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
def _send_password_login(self, user: str, password: str) -> FakeChannel: | |||
return self._send_login(type="m.login.password", user=user, password=password) | |||
def _send_login(self, type, user, **params) -> FakeChannel: | |||
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) | |||
def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel: | |||
params = {"identifier": {"type": "m.id.user", "user": user}, "type": type} | |||
params.update(extra_params) | |||
channel = self.make_request("POST", "/_matrix/client/r0/login", params) | |||
return channel | |||
def _start_delete_device_session(self, access_token, device_id) -> str: | |||
def _start_delete_device_session(self, access_token: str, device_id: str) -> str: | |||
"""Make an initial delete device request, and return the UI Auth session ID""" | |||
channel = self._delete_device(access_token, device_id) | |||
self.assertEqual(channel.code, 401) | |||
@@ -12,12 +12,14 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Optional | |||
from typing import Optional, cast | |||
from unittest.mock import Mock, call | |||
from parameterized import parameterized | |||
from signedjson.key import generate_signing_key | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EventTypes, Membership, PresenceState | |||
from synapse.api.presence import UserPresenceState | |||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | |||
@@ -35,7 +37,9 @@ from synapse.handlers.presence import ( | |||
) | |||
from synapse.rest import admin | |||
from synapse.rest.client import room | |||
from synapse.types import UserID, get_domain_from_id | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, UserID, get_domain_from_id | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
@@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
servlets = [admin.register_servlets] | |||
def prepare(self, reactor, clock, homeserver): | |||
def prepare( | |||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | |||
) -> None: | |||
self.store = homeserver.get_datastores().main | |||
def test_offline_to_online(self): | |||
def test_offline_to_online(self) -> None: | |||
wheel_timer = Mock() | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
any_order=True, | |||
) | |||
def test_online_to_online(self): | |||
def test_online_to_online(self) -> None: | |||
wheel_timer = Mock() | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
any_order=True, | |||
) | |||
def test_online_to_online_last_active_noop(self): | |||
def test_online_to_online_last_active_noop(self) -> None: | |||
wheel_timer = Mock() | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
any_order=True, | |||
) | |||
def test_online_to_online_last_active(self): | |||
def test_online_to_online_last_active(self) -> None: | |||
wheel_timer = Mock() | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
any_order=True, | |||
) | |||
def test_remote_ping_timer(self): | |||
def test_remote_ping_timer(self) -> None: | |||
wheel_timer = Mock() | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
any_order=True, | |||
) | |||
def test_online_to_offline(self): | |||
def test_online_to_offline(self) -> None: | |||
wheel_timer = Mock() | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(wheel_timer.insert.call_count, 0) | |||
def test_online_to_idle(self): | |||
def test_online_to_idle(self) -> None: | |||
wheel_timer = Mock() | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
any_order=True, | |||
) | |||
def test_persisting_presence_updates(self): | |||
def test_persisting_presence_updates(self) -> None: | |||
"""Tests that the latest presence state for each user is persisted correctly""" | |||
# Create some test users and presence states for them | |||
presence_states = [] | |||
@@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
self.get_success(self.store.update_presence(presence_states)) | |||
# Check that each update is present in the database | |||
db_presence_states = self.get_success( | |||
db_presence_states_raw = self.get_success( | |||
self.store.get_all_presence_updates( | |||
instance_name="master", | |||
last_id=0, | |||
@@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Extract presence update user ID and state information into lists of tuples | |||
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]] | |||
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]] | |||
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states] | |||
# Compare what we put into the storage with what we got out. | |||
@@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): | |||
class PresenceTimeoutTestCase(unittest.TestCase): | |||
"""Tests different timers and that the timer does not change `status_msg` of user.""" | |||
def test_idle_timer(self): | |||
def test_idle_timer(self) -> None: | |||
user_id = "@foo:bar" | |||
status_msg = "I'm here!" | |||
now = 5000000 | |||
@@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE) | |||
self.assertEqual(new_state.status_msg, status_msg) | |||
def test_busy_no_idle(self): | |||
def test_busy_no_idle(self) -> None: | |||
""" | |||
Tests that a user setting their presence to busy but idling doesn't turn their | |||
presence state into unavailable. | |||
@@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
self.assertEqual(new_state.state, PresenceState.BUSY) | |||
self.assertEqual(new_state.status_msg, status_msg) | |||
def test_sync_timeout(self): | |||
def test_sync_timeout(self) -> None: | |||
user_id = "@foo:bar" | |||
status_msg = "I'm here!" | |||
now = 5000000 | |||
@@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
self.assertEqual(new_state.state, PresenceState.OFFLINE) | |||
self.assertEqual(new_state.status_msg, status_msg) | |||
def test_sync_online(self): | |||
def test_sync_online(self) -> None: | |||
user_id = "@foo:bar" | |||
status_msg = "I'm here!" | |||
now = 5000000 | |||
@@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
self.assertEqual(new_state.state, PresenceState.ONLINE) | |||
self.assertEqual(new_state.status_msg, status_msg) | |||
def test_federation_ping(self): | |||
def test_federation_ping(self) -> None: | |||
user_id = "@foo:bar" | |||
status_msg = "I'm here!" | |||
now = 5000000 | |||
@@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
self.assertIsNotNone(new_state) | |||
self.assertEqual(state, new_state) | |||
def test_no_timeout(self): | |||
def test_no_timeout(self) -> None: | |||
user_id = "@foo:bar" | |||
now = 5000000 | |||
@@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
self.assertIsNone(new_state) | |||
def test_federation_timeout(self): | |||
def test_federation_timeout(self) -> None: | |||
user_id = "@foo:bar" | |||
status_msg = "I'm here!" | |||
now = 5000000 | |||
@@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
self.assertEqual(new_state.state, PresenceState.OFFLINE) | |||
self.assertEqual(new_state.status_msg, status_msg) | |||
def test_last_active(self): | |||
def test_last_active(self) -> None: | |||
user_id = "@foo:bar" | |||
status_msg = "I'm here!" | |||
now = 5000000 | |||
@@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase): | |||
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.presence_handler = hs.get_presence_handler() | |||
self.clock = hs.get_clock() | |||
def test_external_process_timeout(self): | |||
def test_external_process_timeout(self) -> None: | |||
"""Test that if an external process doesn't update the records for a while | |||
we time out their syncing users presence. | |||
""" | |||
process_id = 1 | |||
process_id = "1" | |||
user_id = "@test:server" | |||
# Notify handler that a user is now syncing. | |||
@@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
) | |||
self.assertEqual(state.state, PresenceState.OFFLINE) | |||
def test_user_goes_offline_by_timeout_status_msg_remain(self): | |||
def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None: | |||
"""Test that if a user doesn't update the records for a while | |||
users presence goes `OFFLINE` because of timeout and `status_msg` remains. | |||
""" | |||
@@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
self.assertEqual(state.state, PresenceState.OFFLINE) | |||
self.assertEqual(state.status_msg, status_msg) | |||
def test_user_goes_offline_manually_with_no_status_msg(self): | |||
def test_user_goes_offline_manually_with_no_status_msg(self) -> None: | |||
"""Test that if a user change presence manually to `OFFLINE` | |||
and no status is set, that `status_msg` is `None`. | |||
""" | |||
@@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
self.assertEqual(state.state, PresenceState.OFFLINE) | |||
self.assertEqual(state.status_msg, None) | |||
def test_user_goes_offline_manually_with_status_msg(self): | |||
def test_user_goes_offline_manually_with_status_msg(self) -> None: | |||
"""Test that if a user change presence manually to `OFFLINE` | |||
and a status is set, that `status_msg` appears. | |||
""" | |||
@@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
user_id, PresenceState.OFFLINE, "And now here." | |||
) | |||
def test_user_reset_online_with_no_status(self): | |||
def test_user_reset_online_with_no_status(self) -> None: | |||
"""Test that if a user set again the presence manually | |||
and no status is set, that `status_msg` is `None`. | |||
""" | |||
@@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
self.assertEqual(state.state, PresenceState.ONLINE) | |||
self.assertEqual(state.status_msg, None) | |||
def test_set_presence_with_status_msg_none(self): | |||
def test_set_presence_with_status_msg_none(self) -> None: | |||
"""Test that if a user set again the presence manually | |||
and status is `None`, that `status_msg` is `None`. | |||
""" | |||
@@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
# Mark user as online and `status_msg = None` | |||
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) | |||
def test_set_presence_from_syncing_not_set(self): | |||
def test_set_presence_from_syncing_not_set(self) -> None: | |||
"""Test that presence is not set by syncing if affect_presence is false""" | |||
user_id = "@test:server" | |||
status_msg = "I'm here!" | |||
@@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
# and status message should still be the same | |||
self.assertEqual(state.status_msg, status_msg) | |||
def test_set_presence_from_syncing_is_set(self): | |||
def test_set_presence_from_syncing_is_set(self) -> None: | |||
"""Test that presence is set by syncing if affect_presence is true""" | |||
user_id = "@test:server" | |||
status_msg = "I'm here!" | |||
@@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
# we should now be online | |||
self.assertEqual(state.state, PresenceState.ONLINE) | |||
def test_set_presence_from_syncing_keeps_status(self): | |||
def test_set_presence_from_syncing_keeps_status(self) -> None: | |||
"""Test that presence set by syncing retains status message""" | |||
user_id = "@test:server" | |||
status_msg = "I'm here!" | |||
@@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
}, | |||
} | |||
) | |||
def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool): | |||
def test_set_presence_from_syncing_keeps_busy( | |||
self, test_with_workers: bool | |||
) -> None: | |||
"""Test that presence set by syncing doesn't affect busy status | |||
Args: | |||
@@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
def _set_presencestate_with_status_msg( | |||
self, user_id: str, state: str, status_msg: Optional[str] | |||
): | |||
) -> None: | |||
"""Set a PresenceState and status_msg and check the result. | |||
Args: | |||
@@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): | |||
class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.presence_handler = hs.get_presence_handler() | |||
self.clock = hs.get_clock() | |||
self.instance_name = hs.get_instance_name() | |||
self.queue = self.presence_handler.get_federation_queue() | |||
def test_send_and_get(self): | |||
def test_send_and_get(self) -> None: | |||
state1 = UserPresenceState.default("@user1:test") | |||
state2 = UserPresenceState.default("@user2:test") | |||
state3 = UserPresenceState.default("@user3:test") | |||
@@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): | |||
self.assertFalse(limited) | |||
self.assertCountEqual(rows, []) | |||
def test_send_and_get_split(self): | |||
def test_send_and_get_split(self) -> None: | |||
state1 = UserPresenceState.default("@user1:test") | |||
state2 = UserPresenceState.default("@user2:test") | |||
state3 = UserPresenceState.default("@user3:test") | |||
@@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): | |||
self.assertCountEqual(rows, expected_rows) | |||
def test_clear_queue_all(self): | |||
def test_clear_queue_all(self) -> None: | |||
state1 = UserPresenceState.default("@user1:test") | |||
state2 = UserPresenceState.default("@user2:test") | |||
state3 = UserPresenceState.default("@user3:test") | |||
@@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): | |||
self.assertCountEqual(rows, expected_rows) | |||
def test_partially_clear_queue(self): | |||
def test_partially_clear_queue(self) -> None: | |||
state1 = UserPresenceState.default("@user1:test") | |||
state2 = UserPresenceState.default("@user2:test") | |||
state3 = UserPresenceState.default("@user3:test") | |||
@@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): | |||
servlets = [room.register_servlets] | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs = self.setup_test_homeserver( | |||
"server", | |||
federation_http_client=None, | |||
@@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): | |||
) | |||
return hs | |||
def default_config(self): | |||
def default_config(self) -> JsonDict: | |||
config = super().default_config() | |||
# Enable federation sending on the main process. | |||
config["federation_sender_instances"] = None | |||
return config | |||
def prepare(self, reactor, clock, hs): | |||
self.federation_sender = hs.get_federation_sender() | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.federation_sender = cast(Mock, hs.get_federation_sender()) | |||
self.event_builder_factory = hs.get_event_builder_factory() | |||
self.federation_event_handler = hs.get_federation_event_handler() | |||
self.presence_handler = hs.get_presence_handler() | |||
@@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): | |||
# random key to use. | |||
self.random_signing_key = generate_signing_key("ver") | |||
def test_remote_joins(self): | |||
def test_remote_joins(self) -> None: | |||
# We advance time to something that isn't 0, as we use 0 as a special | |||
# value. | |||
self.reactor.advance(1000000000000) | |||
@@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): | |||
destinations={"server3"}, states=[expected_state] | |||
) | |||
def test_remote_gets_presence_when_local_user_joins(self): | |||
def test_remote_gets_presence_when_local_user_joins(self) -> None: | |||
# We advance time to something that isn't 0, as we use 0 as a special | |||
# value. | |||
self.reactor.advance(1000000000000) | |||
@@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): | |||
destinations={"server2", "server3"}, states=[expected_state] | |||
) | |||
def _add_new_user(self, room_id, user_id): | |||
def _add_new_user(self, room_id: str, user_id: str) -> None: | |||
"""Add new user to the room by creating an event and poking the federation API.""" | |||
hostname = get_domain_from_id(user_id) | |||
@@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
@unittest.override_config( | |||
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]} | |||
) | |||
def test_avatar_constraint_on_local_server_with_port(self): | |||
def test_avatar_constraint_on_local_server_with_port(self) -> None: | |||
"""Test that avatar metadata is correctly fetched when the media is on a local | |||
server and the server has an explicit port. | |||
@@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc)) | |||
) | |||
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]): | |||
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None: | |||
"""Stores metadata about files in the database. | |||
Args: | |||
@@ -15,14 +15,18 @@ | |||
from copy import deepcopy | |||
from typing import List | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EduTypes, ReceiptTypes | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests import unittest | |||
class ReceiptsTestCase(unittest.HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.event_source = hs.get_event_sources().sources.receipt | |||
def test_filters_out_private_receipt(self) -> None: | |||
@@ -12,8 +12,11 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Collection, List, Optional, Tuple | |||
from unittest.mock import Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.auth import Auth | |||
from synapse.api.constants import UserTypes | |||
from synapse.api.errors import ( | |||
@@ -22,8 +25,18 @@ from synapse.api.errors import ( | |||
ResourceLimitError, | |||
SynapseError, | |||
) | |||
from synapse.module_api import ModuleApi | |||
from synapse.server import HomeServer | |||
from synapse.spam_checker_api import RegistrationBehaviour | |||
from synapse.types import RoomAlias, RoomID, UserID, create_requester | |||
from synapse.types import ( | |||
JsonDict, | |||
Requester, | |||
RoomAlias, | |||
RoomID, | |||
UserID, | |||
create_requester, | |||
) | |||
from synapse.util import Clock | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
@@ -33,94 +46,98 @@ from .. import unittest | |||
class TestSpamChecker: | |||
def __init__(self, config, api): | |||
def __init__(self, config: None, api: ModuleApi): | |||
api.register_spam_checker_callbacks( | |||
check_registration_for_spam=self.check_registration_for_spam, | |||
) | |||
@staticmethod | |||
def parse_config(config): | |||
return config | |||
def parse_config(config: JsonDict) -> None: | |||
return None | |||
async def check_registration_for_spam( | |||
self, | |||
email_threepid, | |||
username, | |||
request_info, | |||
auth_provider_id, | |||
): | |||
email_threepid: Optional[dict], | |||
username: Optional[str], | |||
request_info: Collection[Tuple[str, str]], | |||
auth_provider_id: Optional[str], | |||
) -> RegistrationBehaviour: | |||
pass | |||
class DenyAll(TestSpamChecker): | |||
async def check_registration_for_spam( | |||
self, | |||
email_threepid, | |||
username, | |||
request_info, | |||
auth_provider_id, | |||
): | |||
email_threepid: Optional[dict], | |||
username: Optional[str], | |||
request_info: Collection[Tuple[str, str]], | |||
auth_provider_id: Optional[str], | |||
) -> RegistrationBehaviour: | |||
return RegistrationBehaviour.DENY | |||
class BanAll(TestSpamChecker): | |||
async def check_registration_for_spam( | |||
self, | |||
email_threepid, | |||
username, | |||
request_info, | |||
auth_provider_id, | |||
): | |||
email_threepid: Optional[dict], | |||
username: Optional[str], | |||
request_info: Collection[Tuple[str, str]], | |||
auth_provider_id: Optional[str], | |||
) -> RegistrationBehaviour: | |||
return RegistrationBehaviour.SHADOW_BAN | |||
class BanBadIdPUser(TestSpamChecker): | |||
async def check_registration_for_spam( | |||
self, email_threepid, username, request_info, auth_provider_id=None | |||
): | |||
self, | |||
email_threepid: Optional[dict], | |||
username: Optional[str], | |||
request_info: Collection[Tuple[str, str]], | |||
auth_provider_id: Optional[str] = None, | |||
) -> RegistrationBehaviour: | |||
# Reject any user coming from CAS and whose username contains profanity | |||
if auth_provider_id == "cas" and "flimflob" in username: | |||
if auth_provider_id == "cas" and username and "flimflob" in username: | |||
return RegistrationBehaviour.DENY | |||
return RegistrationBehaviour.ALLOW | |||
class TestLegacyRegistrationSpamChecker: | |||
def __init__(self, config, api): | |||
def __init__(self, config: None, api: ModuleApi): | |||
pass | |||
async def check_registration_for_spam( | |||
self, | |||
email_threepid, | |||
username, | |||
request_info, | |||
): | |||
email_threepid: Optional[dict], | |||
username: Optional[str], | |||
request_info: Collection[Tuple[str, str]], | |||
) -> RegistrationBehaviour: | |||
pass | |||
class LegacyAllowAll(TestLegacyRegistrationSpamChecker): | |||
async def check_registration_for_spam( | |||
self, | |||
email_threepid, | |||
username, | |||
request_info, | |||
): | |||
email_threepid: Optional[dict], | |||
username: Optional[str], | |||
request_info: Collection[Tuple[str, str]], | |||
) -> RegistrationBehaviour: | |||
return RegistrationBehaviour.ALLOW | |||
class LegacyDenyAll(TestLegacyRegistrationSpamChecker): | |||
async def check_registration_for_spam( | |||
self, | |||
email_threepid, | |||
username, | |||
request_info, | |||
): | |||
email_threepid: Optional[dict], | |||
username: Optional[str], | |||
request_info: Collection[Tuple[str, str]], | |||
) -> RegistrationBehaviour: | |||
return RegistrationBehaviour.DENY | |||
class RegistrationTestCase(unittest.HomeserverTestCase): | |||
"""Tests the RegistrationHandler.""" | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
hs_config = self.default_config() | |||
# some of the tests rely on us having a user consent version | |||
@@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
return hs | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.handler = self.hs.get_registration_handler() | |||
self.store = self.hs.get_datastores().main | |||
self.lots_of_users = 100 | |||
@@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.requester = create_requester("@requester:test") | |||
def test_user_is_created_and_logged_in_if_doesnt_exist(self): | |||
def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None: | |||
frank = UserID.from_string("@frank:test") | |||
user_id = frank.to_string() | |||
requester = create_requester(user_id) | |||
@@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.assertIsInstance(result_token, str) | |||
self.assertGreater(len(result_token), 20) | |||
def test_if_user_exists(self): | |||
def test_if_user_exists(self) -> None: | |||
store = self.hs.get_datastores().main | |||
frank = UserID.from_string("@frank:test") | |||
self.get_success( | |||
@@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.assertTrue(result_token is not None) | |||
@override_config({"limit_usage_by_mau": False}) | |||
def test_mau_limits_when_disabled(self): | |||
def test_mau_limits_when_disabled(self) -> None: | |||
# Ensure does not throw exception | |||
self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_get_or_create_user_mau_not_blocked(self): | |||
def test_get_or_create_user_mau_not_blocked(self) -> None: | |||
self.store.count_monthly_users = Mock( | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) | |||
) | |||
@@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.get_success(self.get_or_create_user(self.requester, "c", "User")) | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_get_or_create_user_mau_blocked(self): | |||
def test_get_or_create_user_mau_blocked(self) -> None: | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.lots_of_users) | |||
) | |||
@@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
) | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_register_mau_blocked(self): | |||
def test_register_mau_blocked(self) -> None: | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.lots_of_users) | |||
) | |||
@@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config( | |||
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False} | |||
) | |||
def test_auto_join_rooms_for_guests(self): | |||
def test_auto_join_rooms_for_guests(self) -> None: | |||
user_id = self.get_success( | |||
self.handler.register_user(localpart="jeff", make_guest=True), | |||
) | |||
@@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(len(rooms), 0) | |||
@override_config({"auto_join_rooms": ["#room:test"]}) | |||
def test_auto_create_auto_join_rooms(self): | |||
def test_auto_create_auto_join_rooms(self) -> None: | |||
room_alias_str = "#room:test" | |||
user_id = self.get_success(self.handler.register_user(localpart="jeff")) | |||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | |||
@@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(len(rooms), 1) | |||
@override_config({"auto_join_rooms": []}) | |||
def test_auto_create_auto_join_rooms_with_no_rooms(self): | |||
def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None: | |||
frank = UserID.from_string("@frank:test") | |||
user_id = self.get_success(self.handler.register_user(frank.localpart)) | |||
self.assertEqual(user_id, frank.to_string()) | |||
@@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(len(rooms), 0) | |||
@override_config({"auto_join_rooms": ["#room:another"]}) | |||
def test_auto_create_auto_join_where_room_is_another_domain(self): | |||
def test_auto_create_auto_join_where_room_is_another_domain(self) -> None: | |||
frank = UserID.from_string("@frank:test") | |||
user_id = self.get_success(self.handler.register_user(frank.localpart)) | |||
self.assertEqual(user_id, frank.to_string()) | |||
@@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config( | |||
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False} | |||
) | |||
def test_auto_create_auto_join_where_auto_create_is_false(self): | |||
def test_auto_create_auto_join_where_auto_create_is_false(self) -> None: | |||
user_id = self.get_success(self.handler.register_user(localpart="jeff")) | |||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | |||
self.assertEqual(len(rooms), 0) | |||
@override_config({"auto_join_rooms": ["#room:test"]}) | |||
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self): | |||
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: | |||
room_alias_str = "#room:test" | |||
self.store.is_real_user = Mock(return_value=make_awaitable(False)) | |||
user_id = self.get_success(self.handler.register_user(localpart="support")) | |||
@@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.get_failure(directory_handler.get_association(room_alias), SynapseError) | |||
@override_config({"auto_join_rooms": ["#room:test"]}) | |||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): | |||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: | |||
room_alias_str = "#room:test" | |||
self.store.count_real_users = Mock(return_value=make_awaitable(1)) | |||
@@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(len(rooms), 1) | |||
@override_config({"auto_join_rooms": ["#room:test"]}) | |||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self): | |||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( | |||
self, | |||
) -> None: | |||
self.store.count_real_users = Mock(return_value=make_awaitable(2)) | |||
self.store.is_real_user = Mock(return_value=make_awaitable(True)) | |||
user_id = self.get_success(self.handler.register_user(localpart="real")) | |||
@@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
"autocreate_auto_join_rooms_federated": False, | |||
} | |||
) | |||
def test_auto_create_auto_join_rooms_federated(self): | |||
def test_auto_create_auto_join_rooms_federated(self) -> None: | |||
""" | |||
Auto-created rooms that are private require an invite to go to the user | |||
(instead of directly joining it). | |||
@@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config( | |||
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"} | |||
) | |||
def test_auto_join_mxid_localpart(self): | |||
def test_auto_join_mxid_localpart(self) -> None: | |||
""" | |||
Ensure the user still needs up in the room created by a different user. | |||
""" | |||
@@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
"auto_join_mxid_localpart": "support", | |||
} | |||
) | |||
def test_auto_create_auto_join_room_preset(self): | |||
def test_auto_create_auto_join_room_preset(self) -> None: | |||
""" | |||
Auto-created rooms that are private require an invite to go to the user | |||
(instead of directly joining it). | |||
@@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
"auto_join_mxid_localpart": "support", | |||
} | |||
) | |||
def test_auto_create_auto_join_room_preset_guest(self): | |||
def test_auto_create_auto_join_room_preset_guest(self) -> None: | |||
""" | |||
Auto-created rooms that are private require an invite to go to the user | |||
(instead of directly joining it). | |||
@@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
"auto_join_mxid_localpart": "support", | |||
} | |||
) | |||
def test_auto_create_auto_join_room_preset_invalid_permissions(self): | |||
def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None: | |||
""" | |||
Auto-created rooms that are private require an invite, check that | |||
registration doesn't completely break if the inviter doesn't have proper | |||
@@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
"auto_join_rooms": ["#room:test"], | |||
}, | |||
) | |||
def test_auto_create_auto_join_where_no_consent(self): | |||
def test_auto_create_auto_join_where_no_consent(self) -> None: | |||
"""Test to ensure that the first user is not auto-joined to a room if | |||
they have not given general consent. | |||
""" | |||
@@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | |||
self.assertEqual(len(rooms), 1) | |||
def test_register_support_user(self): | |||
def test_register_support_user(self) -> None: | |||
user_id = self.get_success( | |||
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT) | |||
) | |||
d = self.store.is_support_user(user_id) | |||
self.assertTrue(self.get_success(d)) | |||
def test_register_not_support_user(self): | |||
def test_register_not_support_user(self) -> None: | |||
user_id = self.get_success(self.handler.register_user(localpart="user")) | |||
d = self.store.is_support_user(user_id) | |||
self.assertFalse(self.get_success(d)) | |||
def test_invalid_user_id_length(self): | |||
def test_invalid_user_id_length(self) -> None: | |||
invalid_user_id = "x" * 256 | |||
self.get_failure( | |||
self.handler.register_user(localpart=invalid_user_id), SynapseError | |||
@@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
] | |||
} | |||
) | |||
def test_spam_checker_deny(self): | |||
def test_spam_checker_deny(self) -> None: | |||
"""A spam checker can deny registration, which results in an error.""" | |||
self.get_failure(self.handler.register_user(localpart="user"), SynapseError) | |||
@@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
] | |||
} | |||
) | |||
def test_spam_checker_legacy_allow(self): | |||
def test_spam_checker_legacy_allow(self) -> None: | |||
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the | |||
check_registration_for_spam callback is correctly called. | |||
@@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
] | |||
} | |||
) | |||
def test_spam_checker_legacy_deny(self): | |||
def test_spam_checker_legacy_deny(self) -> None: | |||
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the | |||
check_registration_for_spam callback is correctly called. | |||
@@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
] | |||
} | |||
) | |||
def test_spam_checker_shadow_ban(self): | |||
def test_spam_checker_shadow_ban(self) -> None: | |||
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" | |||
user_id = self.get_success(self.handler.register_user(localpart="user")) | |||
@@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
] | |||
} | |||
) | |||
def test_spam_checker_receives_sso_type(self): | |||
def test_spam_checker_receives_sso_type(self) -> None: | |||
"""Test rejecting registration based on SSO type""" | |||
f = self.get_failure( | |||
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"), | |||
@@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
) | |||
async def get_or_create_user( | |||
self, requester, localpart, displayname, password_hash=None | |||
): | |||
self, | |||
requester: Requester, | |||
localpart: str, | |||
displayname: Optional[str], | |||
password_hash: Optional[str] = None, | |||
) -> Tuple[str, str]: | |||
"""Creates a new user if the user does not exist, | |||
else revokes all previous access tokens and generates a new one. | |||
@@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): | |||
"""Tests auto-join on remote rooms.""" | |||
def make_homeserver(self, reactor, clock): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.room_id = "!roomid:remotetest" | |||
async def update_membership(*args, **kwargs): | |||
async def update_membership(*args: Any, **kwargs: Any) -> None: | |||
pass | |||
async def lookup_room_alias(*args, **kwargs): | |||
async def lookup_room_alias( | |||
*args: Any, **kwargs: Any | |||
) -> Tuple[RoomID, List[str]]: | |||
return RoomID.from_string(self.room_id), ["remotetest"] | |||
self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"]) | |||
@@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): | |||
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler) | |||
return hs | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.handler = self.hs.get_registration_handler() | |||
self.store = self.hs.get_datastores().main | |||
@override_config({"auto_join_rooms": ["#room:remotetest"]}) | |||
def test_auto_create_auto_join_remote_room(self): | |||
def test_auto_create_auto_join_remote_room(self) -> None: | |||
"""Tests that we don't attempt to create remote rooms, and that we don't attempt | |||
to invite ourselves to rooms we're not in.""" | |||
@@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): | |||
] | |||
@override_config({"encryption_enabled_by_default_for_room_type": "all"}) | |||
def test_encrypted_by_default_config_option_all(self): | |||
def test_encrypted_by_default_config_option_all(self) -> None: | |||
"""Tests that invite-only and non-invite-only rooms have encryption enabled by | |||
default when the config option encryption_enabled_by_default_for_room_type is "all". | |||
""" | |||
@@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) | |||
@override_config({"encryption_enabled_by_default_for_room_type": "invite"}) | |||
def test_encrypted_by_default_config_option_invite(self): | |||
def test_encrypted_by_default_config_option_invite(self) -> None: | |||
"""Tests that only new, invite-only rooms have encryption enabled by default when | |||
the config option encryption_enabled_by_default_for_room_type is "invite". | |||
""" | |||
@@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): | |||
) | |||
@override_config({"encryption_enabled_by_default_for_room_type": "off"}) | |||
def test_encrypted_by_default_config_option_off(self): | |||
def test_encrypted_by_default_config_option_off(self) -> None: | |||
"""Tests that neither new invite-only nor non-invite-only rooms have encryption | |||
enabled by default when the config option | |||
encryption_enabled_by_default_for_room_type is "off". | |||
@@ -11,10 +11,11 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Iterable, List, Optional, Tuple | |||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple | |||
from unittest import mock | |||
from twisted.internet.defer import ensureDeferred | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import ( | |||
EventContentFields, | |||
@@ -34,11 +35,14 @@ from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, UserID, create_requester | |||
from synapse.util import Clock | |||
from tests import unittest | |||
def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0): | |||
def _create_event( | |||
room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0 | |||
) -> mock.Mock: | |||
result = mock.Mock(name=room_id) | |||
result.room_id = room_id | |||
result.content = {} | |||
@@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i | |||
return result | |||
def _order(*events): | |||
def _order(*events: mock.Mock) -> List[mock.Mock]: | |||
return sorted(events, key=_child_events_comparison_key) | |||
class TestSpaceSummarySort(unittest.TestCase): | |||
def test_no_order_last(self): | |||
def test_no_order_last(self) -> None: | |||
"""An event with no ordering is placed behind those with an ordering.""" | |||
ev1 = _create_event("!abc:test") | |||
ev2 = _create_event("!xyz:test", "xyz") | |||
self.assertEqual([ev2, ev1], _order(ev1, ev2)) | |||
def test_order(self): | |||
def test_order(self) -> None: | |||
"""The ordering should be used.""" | |||
ev1 = _create_event("!abc:test", "xyz") | |||
ev2 = _create_event("!xyz:test", "abc") | |||
self.assertEqual([ev2, ev1], _order(ev1, ev2)) | |||
def test_order_origin_server_ts(self): | |||
def test_order_origin_server_ts(self) -> None: | |||
"""Origin server is a tie-breaker for ordering.""" | |||
ev1 = _create_event("!abc:test", origin_server_ts=10) | |||
ev2 = _create_event("!xyz:test", origin_server_ts=30) | |||
self.assertEqual([ev1, ev2], _order(ev1, ev2)) | |||
def test_order_room_id(self): | |||
def test_order_room_id(self) -> None: | |||
"""Room ID is a final tie-breaker for ordering.""" | |||
ev1 = _create_event("!abc:test") | |||
ev2 = _create_event("!xyz:test") | |||
self.assertEqual([ev1, ev2], _order(ev1, ev2)) | |||
def test_invalid_ordering_type(self): | |||
def test_invalid_ordering_type(self) -> None: | |||
"""Invalid orderings are considered the same as missing.""" | |||
ev1 = _create_event("!abc:test", 1) | |||
ev2 = _create_event("!xyz:test", "xyz") | |||
@@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase): | |||
ev1 = _create_event("!abc:test", True) | |||
self.assertEqual([ev2, ev1], _order(ev1, ev2)) | |||
def test_invalid_ordering_value(self): | |||
def test_invalid_ordering_value(self) -> None: | |||
"""Invalid orderings are considered the same as missing.""" | |||
ev1 = _create_event("!abc:test", "foo\n") | |||
ev2 = _create_event("!xyz:test", "xyz") | |||
@@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.hs = hs | |||
self.handler = self.hs.get_room_summary_handler() | |||
@@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) | |||
) | |||
def test_simple_space(self): | |||
def test_simple_space(self) -> None: | |||
"""Test a simple space with a single room.""" | |||
# The result should have the space and the room in it, along with a link | |||
# from space -> room. | |||
@@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self._assert_hierarchy(result, expected) | |||
def test_large_space(self): | |||
def test_large_space(self) -> None: | |||
"""Test a space with a large number of rooms.""" | |||
rooms = [self.room] | |||
# Make at least 51 rooms that are part of the space. | |||
@@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
result["rooms"] += result2["rooms"] | |||
self._assert_hierarchy(result, expected) | |||
def test_visibility(self): | |||
def test_visibility(self) -> None: | |||
"""A user not in a space cannot inspect it.""" | |||
user2 = self.register_user("user2", "pass") | |||
token2 = self.login("user2", "pass") | |||
@@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
self._assert_hierarchy(result2, [(self.space, [self.room])]) | |||
def _create_room_with_join_rule( | |||
self, join_rule: str, room_version: Optional[str] = None, **extra_content | |||
self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any | |||
) -> str: | |||
"""Create a room with the given join rule and add it to the space.""" | |||
room_id = self.helper.create_room_as( | |||
@@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
self._add_child(self.space, room_id, self.token) | |||
return room_id | |||
def test_filtering(self): | |||
def test_filtering(self) -> None: | |||
""" | |||
Rooms should be properly filtered to only include rooms the user has access to. | |||
""" | |||
@@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self._assert_hierarchy(result, expected) | |||
def test_complex_space(self): | |||
def test_complex_space(self) -> None: | |||
""" | |||
Create a "complex" space to see how it handles things like loops and subspaces. | |||
""" | |||
@@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self._assert_hierarchy(result, expected) | |||
def test_pagination(self): | |||
def test_pagination(self) -> None: | |||
"""Test simple pagination works.""" | |||
room_ids = [] | |||
for i in range(1, 10): | |||
@@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
self._assert_hierarchy(result, expected) | |||
self.assertNotIn("next_batch", result) | |||
def test_invalid_pagination_token(self): | |||
def test_invalid_pagination_token(self) -> None: | |||
"""An invalid pagination token, or changing other parameters, shoudl be rejected.""" | |||
room_ids = [] | |||
for i in range(1, 10): | |||
@@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
SynapseError, | |||
) | |||
def test_max_depth(self): | |||
def test_max_depth(self) -> None: | |||
"""Create a deep tree to test the max depth against.""" | |||
spaces = [self.space] | |||
rooms = [self.room] | |||
@@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
] | |||
self._assert_hierarchy(result, expected) | |||
def test_unknown_room_version(self): | |||
def test_unknown_room_version(self) -> None: | |||
""" | |||
If a room with an unknown room version is encountered it should not cause | |||
the entire summary to skip. | |||
@@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self._assert_hierarchy(result, expected) | |||
def test_fed_complex(self): | |||
def test_fed_complex(self) -> None: | |||
""" | |||
Return data over federation and ensure that it is handled properly. | |||
""" | |||
@@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
"world_readable": True, | |||
} | |||
async def summarize_remote_room_hierarchy(_self, room, suggested_only): | |||
async def summarize_remote_room_hierarchy( | |||
_self: Any, room: Any, suggested_only: bool | |||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: | |||
return requested_room_entry, {subroom: child_room}, set() | |||
# Add a room to the space which is on another server. | |||
@@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self._assert_hierarchy(result, expected) | |||
def test_fed_filtering(self): | |||
def test_fed_filtering(self) -> None: | |||
""" | |||
Rooms returned over federation should be properly filtered to only include | |||
rooms the user has access to. | |||
@@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
], | |||
) | |||
async def summarize_remote_room_hierarchy(_self, room, suggested_only): | |||
async def summarize_remote_room_hierarchy( | |||
_self: Any, room: Any, suggested_only: bool | |||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: | |||
return subspace_room_entry, dict(children_rooms), set() | |||
# Add a room to the space which is on another server. | |||
@@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self._assert_hierarchy(result, expected) | |||
def test_fed_invited(self): | |||
def test_fed_invited(self) -> None: | |||
""" | |||
A room which the user was invited to should be included in the response. | |||
@@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
}, | |||
) | |||
async def summarize_remote_room_hierarchy(_self, room, suggested_only): | |||
async def summarize_remote_room_hierarchy( | |||
_self: Any, room: Any, suggested_only: bool | |||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: | |||
return fed_room_entry, {}, set() | |||
# Add a room to the space which is on another server. | |||
@@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): | |||
) | |||
self._assert_hierarchy(result, expected) | |||
def test_fed_caching(self): | |||
def test_fed_caching(self) -> None: | |||
""" | |||
Federation `/hierarchy` responses should be cached. | |||
""" | |||
@@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.hs = hs | |||
self.handler = self.hs.get_room_summary_handler() | |||
@@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): | |||
tok=self.token, | |||
) | |||
def test_own_room(self): | |||
def test_own_room(self) -> None: | |||
"""Test a simple room created by the requester.""" | |||
result = self.get_success(self.handler.get_room_summary(self.user, self.room)) | |||
self.assertEqual(result.get("room_id"), self.room) | |||
def test_visibility(self): | |||
def test_visibility(self) -> None: | |||
"""A user not in a private room cannot get its summary.""" | |||
user2 = self.register_user("user2", "pass") | |||
token2 = self.login("user2", "pass") | |||
@@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): | |||
result = self.get_success(self.handler.get_room_summary(user2, self.room)) | |||
self.assertEqual(result.get("room_id"), self.room) | |||
def test_fed(self): | |||
def test_fed(self) -> None: | |||
""" | |||
Return data over federation and ensure that it is handled properly. | |||
""" | |||
@@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): | |||
{"room_id": fed_room, "world_readable": True}, | |||
) | |||
async def summarize_remote_room_hierarchy(_self, room, suggested_only): | |||
async def summarize_remote_room_hierarchy( | |||
_self: Any, room: Any, suggested_only: bool | |||
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]: | |||
return requested_room_entry, {}, set() | |||
with mock.patch( | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Dict, Optional | |||
from typing import Any, Dict, Optional, Set, Tuple | |||
from unittest.mock import Mock | |||
import attr | |||
@@ -20,7 +20,9 @@ import attr | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.errors import RedirectException | |||
from synapse.module_api import ModuleApi | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests.test_utils import simple_async_mock | |||
@@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config | |||
# Check if we have the dependencies to run the tests. | |||
try: | |||
import saml2.config | |||
import saml2.response | |||
from saml2.sigver import SigverError | |||
has_saml2 = True | |||
@@ -56,31 +59,39 @@ class FakeAuthnResponse: | |||
class TestMappingProvider: | |||
def __init__(self, config, module): | |||
def __init__(self, config: None, module: ModuleApi): | |||
pass | |||
@staticmethod | |||
def parse_config(config): | |||
return | |||
def parse_config(config: JsonDict) -> None: | |||
return None | |||
@staticmethod | |||
def get_saml_attributes(config): | |||
def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]: | |||
return {"uid"}, {"displayName"} | |||
def get_remote_user_id(self, saml_response, client_redirect_url): | |||
def get_remote_user_id( | |||
self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str | |||
) -> str: | |||
return saml_response.ava["uid"] | |||
def saml_response_to_user_attributes( | |||
self, saml_response, failures, client_redirect_url | |||
): | |||
self, | |||
saml_response: "saml2.response.AuthnResponse", | |||
failures: int, | |||
client_redirect_url: str, | |||
) -> dict: | |||
localpart = saml_response.ava["username"] + (str(failures) if failures else "") | |||
return {"mxid_localpart": localpart, "displayname": None} | |||
class TestRedirectMappingProvider(TestMappingProvider): | |||
def saml_response_to_user_attributes( | |||
self, saml_response, failures, client_redirect_url | |||
): | |||
self, | |||
saml_response: "saml2.response.AuthnResponse", | |||
failures: int, | |||
client_redirect_url: str, | |||
) -> dict: | |||
raise RedirectException(b"https://custom-saml-redirect/") | |||
@@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||
) | |||
def _mock_request(): | |||
def _mock_request() -> Mock: | |||
"""Returns a mock which will stand in as a SynapseRequest""" | |||
mock = Mock( | |||
spec=[ | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
from typing import List, Tuple | |||
from typing import Callable, List, Tuple | |||
from zope.interface import implementer | |||
@@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config | |||
@implementer(interfaces.IMessageDelivery) | |||
class _DummyMessageDelivery: | |||
def __init__(self): | |||
def __init__(self) -> None: | |||
# (recipient, message) tuples | |||
self.messages: List[Tuple[smtp.Address, bytes]] = [] | |||
def receivedHeader(self, helo, origin, recipients): | |||
def receivedHeader( | |||
self, | |||
helo: Tuple[bytes, bytes], | |||
origin: smtp.Address, | |||
recipients: List[smtp.User], | |||
) -> None: | |||
return None | |||
def validateFrom(self, helo, origin): | |||
def validateFrom( | |||
self, helo: Tuple[bytes, bytes], origin: smtp.Address | |||
) -> smtp.Address: | |||
return origin | |||
def record_message(self, recipient: smtp.Address, message: bytes): | |||
def record_message(self, recipient: smtp.Address, message: bytes) -> None: | |||
self.messages.append((recipient, message)) | |||
def validateTo(self, user: smtp.User): | |||
def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]: | |||
return lambda: _DummyMessage(self, user) | |||
@@ -56,20 +63,20 @@ class _DummyMessage: | |||
self._user = user | |||
self._buffer: List[bytes] = [] | |||
def lineReceived(self, line): | |||
def lineReceived(self, line: bytes) -> None: | |||
self._buffer.append(line) | |||
def eomReceived(self): | |||
def eomReceived(self) -> "defer.Deferred[bytes]": | |||
message = b"\n".join(self._buffer) + b"\n" | |||
self._delivery.record_message(self._user.dest, message) | |||
return defer.succeed(b"saved") | |||
def connectionLost(self): | |||
def connectionLost(self) -> None: | |||
pass | |||
class SendEmailHandlerTestCase(HomeserverTestCase): | |||
def test_send_email(self): | |||
def test_send_email(self) -> None: | |||
"""Happy-path test that we can send email to a non-TLS server.""" | |||
h = self.hs.get_send_email_handler() | |||
d = ensureDeferred( | |||
@@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase): | |||
}, | |||
} | |||
) | |||
def test_send_email_force_tls(self): | |||
def test_send_email_force_tls(self) -> None: | |||
"""Happy-path test that we can send email to an Implicit TLS server.""" | |||
h = self.hs.get_send_email_handler() | |||
d = ensureDeferred( | |||
@@ -12,9 +12,15 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Dict, List, Optional | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.rest import admin | |||
from synapse.rest.client import login, room | |||
from synapse.server import HomeServer | |||
from synapse.storage.databases.main import stats | |||
from synapse.util import Clock | |||
from tests import unittest | |||
@@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
login.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.store = hs.get_datastores().main | |||
self.handler = self.hs.get_stats_handler() | |||
def _add_background_updates(self): | |||
def _add_background_updates(self) -> None: | |||
""" | |||
Add the background updates we need to run. | |||
""" | |||
@@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
) | |||
) | |||
async def get_all_room_state(self): | |||
async def get_all_room_state(self) -> List[Dict[str, Any]]: | |||
return await self.store.db_pool.simple_select_list( | |||
"room_stats_state", None, retcols=("name", "topic", "canonical_alias") | |||
) | |||
def _get_current_stats(self, stats_type, stat_id): | |||
def _get_current_stats( | |||
self, stats_type: str, stat_id: str | |||
) -> Optional[Dict[str, Any]]: | |||
table, id_col = stats.TYPE_TO_TABLE[stats_type] | |||
cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) | |||
@@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
) | |||
) | |||
def _perform_background_initial_update(self): | |||
def _perform_background_initial_update(self) -> None: | |||
# Do the initial population of the stats via the background update | |||
self._add_background_updates() | |||
self.wait_for_background_updates() | |||
def test_initial_room(self): | |||
def test_initial_room(self) -> None: | |||
""" | |||
The background updates will build the table from scratch. | |||
""" | |||
@@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.assertEqual(len(r), 1) | |||
self.assertEqual(r[0]["topic"], "foo") | |||
def test_create_user(self): | |||
def test_create_user(self) -> None: | |||
""" | |||
When we create a user, it should have statistics already ready. | |||
""" | |||
@@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
u1stats = self._get_current_stats("user", u1) | |||
self.assertIsNotNone(u1stats) | |||
assert u1stats is not None | |||
# not in any rooms by default | |||
self.assertEqual(u1stats["joined_rooms"], 0) | |||
def test_create_room(self): | |||
def test_create_room(self) -> None: | |||
""" | |||
When we create a room, it should have statistics already ready. | |||
""" | |||
@@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False) | |||
r2stats = self._get_current_stats("room", r2) | |||
self.assertIsNotNone(r1stats) | |||
self.assertIsNotNone(r2stats) | |||
assert r1stats is not None | |||
assert r2stats is not None | |||
self.assertEqual( | |||
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM | |||
@@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.assertEqual(r2stats["invited_members"], 0) | |||
self.assertEqual(r2stats["banned_members"], 0) | |||
def test_updating_profile_information_does_not_increase_joined_members_count(self): | |||
def test_updating_profile_information_does_not_increase_joined_members_count( | |||
self, | |||
) -> None: | |||
""" | |||
Check that the joined_members count does not increase when a user changes their | |||
profile information (which is done by sending another join membership event into | |||
@@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
# Get the current room stats | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
# Send a profile update into the room | |||
new_profile = {"displayname": "bob"} | |||
@@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
# Get the new room stats | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
# Ensure that the user count did not changed | |||
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"]) | |||
@@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"] | |||
) | |||
def test_send_state_event_nonoverwriting(self): | |||
def test_send_state_event_nonoverwriting(self) -> None: | |||
""" | |||
When we send a non-overwriting state event, it increments current_state_events | |||
""" | |||
@@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
) | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
self.helper.send_state( | |||
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy" | |||
) | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
self.assertEqual( | |||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], | |||
1, | |||
) | |||
def test_join_first_time(self): | |||
def test_join_first_time(self) -> None: | |||
""" | |||
When a user joins a room for the first time, current_state_events and | |||
joined_members should increase by exactly 1. | |||
@@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
u2token = self.login("u2", "pass") | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
self.helper.join(r1, u2, tok=u2token) | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
self.assertEqual( | |||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], | |||
@@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1 | |||
) | |||
def test_join_after_leave(self): | |||
def test_join_after_leave(self) -> None: | |||
""" | |||
When a user joins a room after being previously left, | |||
joined_members should increase by exactly 1. | |||
@@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.helper.leave(r1, u2, tok=u2token) | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
self.helper.join(r1, u2, tok=u2token) | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
self.assertEqual( | |||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], | |||
@@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats_post["left_members"] - r1stats_ante["left_members"], -1 | |||
) | |||
def test_invited(self): | |||
def test_invited(self) -> None: | |||
""" | |||
When a user invites another user, current_state_events and | |||
invited_members should increase by exactly 1. | |||
@@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
u2 = self.register_user("u2", "pass") | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
self.helper.invite(r1, u1, u2, tok=u1token) | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
self.assertEqual( | |||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], | |||
@@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1 | |||
) | |||
def test_join_after_invite(self): | |||
def test_join_after_invite(self) -> None: | |||
""" | |||
When a user joins a room after being invited and | |||
joined_members should increase by exactly 1. | |||
@@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.helper.invite(r1, u1, u2, tok=u1token) | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
self.helper.join(r1, u2, tok=u2token) | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
self.assertEqual( | |||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], | |||
@@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1 | |||
) | |||
def test_left(self): | |||
def test_left(self) -> None: | |||
""" | |||
When a user leaves a room after joining and | |||
left_members should increase by exactly 1. | |||
@@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.helper.join(r1, u2, tok=u2token) | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
self.helper.leave(r1, u2, tok=u2token) | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
self.assertEqual( | |||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], | |||
@@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 | |||
) | |||
def test_banned(self): | |||
def test_banned(self) -> None: | |||
""" | |||
When a user is banned from a room after joining and | |||
left_members should increase by exactly 1. | |||
@@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.helper.join(r1, u2, tok=u2token) | |||
r1stats_ante = self._get_current_stats("room", r1) | |||
assert r1stats_ante is not None | |||
self.helper.change_membership(r1, u1, u2, "ban", tok=u1token) | |||
r1stats_post = self._get_current_stats("room", r1) | |||
assert r1stats_post is not None | |||
self.assertEqual( | |||
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"], | |||
@@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1 | |||
) | |||
def test_initial_background_update(self): | |||
def test_initial_background_update(self) -> None: | |||
""" | |||
Test that statistics can be generated by the initial background update | |||
handler. | |||
@@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
r1stats = self._get_current_stats("room", r1) | |||
u1stats = self._get_current_stats("user", u1) | |||
assert r1stats is not None | |||
assert u1stats is not None | |||
self.assertEqual(r1stats["joined_members"], 1) | |||
self.assertEqual( | |||
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM | |||
@@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.assertEqual(u1stats["joined_rooms"], 1) | |||
def test_incomplete_stats(self): | |||
def test_incomplete_stats(self) -> None: | |||
""" | |||
This tests that we track incomplete statistics. | |||
@@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase): | |||
self.wait_for_background_updates() | |||
r1stats_complete = self._get_current_stats("room", r1) | |||
assert r1stats_complete is not None | |||
u1stats_complete = self._get_current_stats("user", u1) | |||
assert u1stats_complete is not None | |||
u2stats_complete = self._get_current_stats("user", u2) | |||
assert u2stats_complete is not None | |||
# now we make our assertions | |||
@@ -14,6 +14,8 @@ | |||
from typing import Optional | |||
from unittest.mock import MagicMock, Mock, patch | |||
from twisted.test.proto_helpers import MemoryReactor | |||
from synapse.api.constants import EventTypes, JoinRules | |||
from synapse.api.errors import Codes, ResourceLimitError | |||
from synapse.api.filtering import Filtering | |||
@@ -23,6 +25,7 @@ from synapse.rest import admin | |||
from synapse.rest.client import knock, login, room | |||
from synapse.server import HomeServer | |||
from synapse.types import UserID, create_requester | |||
from synapse.util import Clock | |||
import tests.unittest | |||
import tests.utils | |||
@@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): | |||
room.register_servlets, | |||
] | |||
def prepare(self, reactor, clock, hs: HomeServer): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.sync_handler = self.hs.get_sync_handler() | |||
self.store = self.hs.get_datastores().main | |||
@@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): | |||
# modify its config instead of the hs' | |||
self.auth_blocking = self.hs.get_auth_blocking() | |||
def test_wait_for_sync_for_user_auth_blocking(self): | |||
def test_wait_for_sync_for_user_auth_blocking(self) -> None: | |||
user_id1 = "@user1:test" | |||
user_id2 = "@user2:test" | |||
sync_config = generate_sync_config(user_id1) | |||
@@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) | |||
def test_unknown_room_version(self): | |||
def test_unknown_room_version(self) -> None: | |||
""" | |||
A room with an unknown room version should not break sync (and should be excluded). | |||
""" | |||
@@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): | |||
self.assertNotIn(invite_room, [r.room_id for r in result.invited]) | |||
self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) | |||
def test_ban_wins_race_with_join(self): | |||
def test_ban_wins_race_with_join(self) -> None: | |||
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban. | |||
A complicated edge case. Imagine the following scenario: | |||