Browse Source

Add missing type hints to tests.handlers. (#14680)

And do not allow untyped defs in tests.handlers.
tags/v1.75.0rc1
Patrick Cloke 1 year ago
committed by GitHub
parent
commit
652d1669c5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 527 additions and 378 deletions
  1. +1
    -0
      changelog.d/14680.misc
  2. +1
    -4
      mypy.ini
  3. +1
    -1
      synapse/handlers/auth.py
  4. +29
    -25
      tests/handlers/test_appservice.py
  5. +1
    -1
      tests/handlers/test_cas.py
  6. +15
    -12
      tests/handlers/test_directory.py
  7. +44
    -32
      tests/handlers/test_e2e_room_keys.py
  8. +1
    -1
      tests/handlers/test_federation.py
  9. +8
    -2
      tests/handlers/test_federation_event.py
  10. +16
    -10
      tests/handlers/test_message.py
  11. +31
    -17
      tests/handlers/test_oidc.py
  12. +74
    -70
      tests/handlers/test_password_providers.py
  13. +54
    -46
      tests/handlers/test_presence.py
  14. +2
    -2
      tests/handlers/test_profile.py
  15. +5
    -1
      tests/handlers/test_receipts.py
  16. +97
    -72
      tests/handlers/test_register.py
  17. +3
    -3
      tests/handlers/test_room.py
  18. +44
    -32
      tests/handlers/test_room_summary.py
  19. +22
    -11
      tests/handlers/test_saml.py
  20. +18
    -11
      tests/handlers/test_send_email.py
  21. +53
    -21
      tests/handlers/test_stats.py
  22. +7
    -4
      tests/handlers/test_sync.py

+ 1
- 0
changelog.d/14680.misc View File

@@ -0,0 +1 @@
Add missing type hints.

+ 1
- 4
mypy.ini View File

@@ -95,10 +95,7 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True

[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True

[mypy-tests.handlers.test_user_directory]
[mypy-tests.handlers.*]
disallow_untyped_defs = True

[mypy-tests.metrics.test_background_process_metrics]


+ 1
- 1
synapse/handlers/auth.py View File

@@ -2031,7 +2031,7 @@ class PasswordAuthProvider:
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []

# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Iterable[str]] = {}
self._supported_login_types: Dict[str, Tuple[str, ...]] = {}

# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}


+ 29
- 25
tests/handlers/test_appservice.py View File

@@ -31,7 +31,7 @@ from synapse.appservice import (
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
from synapse.types import RoomStreamToken
from synapse.types import JsonDict, RoomStreamToken
from synapse.util import Clock
from synapse.util.stringutils import random_string

@@ -44,7 +44,7 @@ from tests.utils import MockClock
class AppServiceHandlerTestCase(unittest.TestCase):
"""Tests the ApplicationServicesHandler."""

def setUp(self):
def setUp(self) -> None:
self.mock_store = Mock()
self.mock_as_api = Mock()
self.mock_scheduler = Mock()
@@ -61,7 +61,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.handler = ApplicationServicesHandler(hs)
self.event_source = hs.get_event_sources()

def test_notify_interested_services(self):
def test_notify_interested_services(self) -> None:
interested_service = self._mkservice(is_interested_in_event=True)
services = [
self._mkservice(is_interested_in_event=False),
@@ -90,7 +90,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, events=[event]
)

def test_query_user_exists_unknown_user(self):
def test_query_user_exists_unknown_user(self) -> None:
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):

self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)

def test_query_user_exists_known_user(self):
def test_query_user_exists_known_user(self) -> None:
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
@@ -127,7 +127,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
"query_user called when it shouldn't have been.",
)

def test_query_room_alias_exists(self):
def test_query_room_alias_exists(self) -> None:
room_alias_str = "#foo:bar"
room_alias = Mock()
room_alias.to_string.return_value = room_alias_str
@@ -157,7 +157,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.assertEqual(result.room_id, room_id)
self.assertEqual(result.servers, servers)

def test_get_3pe_protocols_no_appservices(self):
def test_get_3pe_protocols_no_appservices(self) -> None:
self.mock_store.get_app_services.return_value = []
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
@@ -165,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {})

def test_get_3pe_protocols_no_protocols(self):
def test_get_3pe_protocols_no_protocols(self) -> None:
service = self._mkservice(False, [])
self.mock_store.get_app_services.return_value = [service]
response = self.successResultOf(
@@ -174,7 +174,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.get_3pe_protocol.assert_not_called()
self.assertEqual(response, {})

def test_get_3pe_protocols_protocol_no_response(self):
def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
@@ -186,7 +186,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
)
self.assertEqual(response, {})

def test_get_3pe_protocols_select_one_protocol(self):
def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@@ -202,7 +202,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
)

def test_get_3pe_protocols_one_protocol(self):
def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
@@ -218,7 +218,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
)

def test_get_3pe_protocols_multiple_protocol(self):
def test_get_3pe_protocols_multiple_protocol(self) -> None:
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two]
@@ -237,11 +237,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)

def test_get_3pe_protocols_multiple_info(self):
def test_get_3pe_protocols_multiple_info(self) -> None:
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["my-protocol"])

async def get_3pe_protocol(service, unusedProtocol):
async def get_3pe_protocol(
service: ApplicationService, protocol: str
) -> Optional[JsonDict]:
if service == service_one:
return {
"x-protocol-data": 42,
@@ -276,7 +278,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
},
)

def test_notify_interested_services_ephemeral(self):
def test_notify_interested_services_ephemeral(self) -> None:
"""
Test sending ephemeral events to the appservice handler are scheduled
to be pushed out to interested appservices, and that the stream ID is
@@ -306,7 +308,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
580,
)

def test_notify_interested_services_ephemeral_out_of_order(self):
def test_notify_interested_services_ephemeral_out_of_order(self) -> None:
"""
Test sending out of order ephemeral events to the appservice handler
are ignored.
@@ -390,7 +392,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
@@ -417,7 +419,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)

def _notify_interested_services(self):
def _notify_interested_services(self) -> None:
# This is normally set in `notify_interested_services` but we need to call the
# internal async version so the reactor gets pushed to completion.
self.hs.get_application_service_handler().current_max += 1
@@ -443,7 +445,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
)
def test_match_interesting_room_members(
self, interesting_user: str, should_notify: bool
):
) -> None:
"""
Test to make sure that a interesting user (local or remote) in the room is
notified as expected when someone else in the room sends a message.
@@ -512,7 +514,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
else:
self.send_mock.assert_not_called()

def test_application_services_receive_events_sent_by_interesting_local_user(self):
def test_application_services_receive_events_sent_by_interesting_local_user(
self,
) -> None:
"""
Test to make sure that a messages sent from a local user can be interesting and
picked up by the appservice.
@@ -568,7 +572,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["type"], "m.room.message")
self.assertEqual(events[0]["sender"], alice)

def test_sending_read_receipt_batches_to_application_services(self):
def test_sending_read_receipt_batches_to_application_services(self) -> None:
"""Tests that a large batch of read receipts are sent correctly to
interested application services.
"""
@@ -644,7 +648,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
def test_application_services_receive_local_to_device(self):
def test_application_services_receive_local_to_device(self) -> None:
"""
Test that when a user sends a to-device message to another user
that is an application service's user namespace, the
@@ -722,7 +726,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
def test_application_services_receive_bursts_of_to_device(self):
def test_application_services_receive_bursts_of_to_device(self) -> None:
"""
Test that when a user sends >100 to-device messages at once, any
interested AS's will receive them in separate transactions.
@@ -913,7 +917,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
experimental_feature_enabled: bool,
as_supports_txn_extensions: bool,
as_should_receive_device_list_updates: bool,
):
) -> None:
"""
Tests that an application service receives notice of changed device
lists for a user, when a user changes their device lists.
@@ -1070,7 +1074,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
and a room for the users to talk in.
"""

async def preparation():
async def preparation() -> None:
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
await self._add_fallback_key_for_device(
self._sender_user, self._sender_device, used=True


+ 1
- 1
tests/handlers/test_cas.py View File

@@ -199,7 +199,7 @@ class CasHandlerTestCase(HomeserverTestCase):
)


def _mock_request():
def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
mock = Mock(
spec=[


+ 15
- 12
tests/handlers/test_directory.py View File

@@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.api.errors
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.rest.client import directory, login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, create_requester
@@ -201,7 +202,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
self.test_user_tok = self.login("user", "pass")
self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)

def _create_alias(self, user) -> None:
def _create_alias(self, user: str) -> None:
# Create a new alias to this room.
self.get_success(
self.store.create_room_alias_association(
@@ -324,7 +325,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)
return room_alias

def _set_canonical_alias(self, content) -> None:
def _set_canonical_alias(self, content: JsonDict) -> None:
"""Configure the canonical alias state on the room."""
self.helper.send_state(
self.room_id,
@@ -333,13 +334,15 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok,
)

def _get_canonical_alias(self):
def _get_canonical_alias(self) -> EventBase:
"""Get the canonical alias state of the room."""
return self.get_success(
result = self.get_success(
self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
assert result is not None
return result

def test_remove_alias(self) -> None:
"""Removing an alias that is the canonical alias should remove it there too."""
@@ -349,8 +352,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)

data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual(data.content["alt_aliases"], [self.test_alias])

# Finally, delete the alias.
self.get_success(
@@ -360,8 +363,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)

data = self._get_canonical_alias()
self.assertNotIn("alias", data["content"])
self.assertNotIn("alt_aliases", data["content"])
self.assertNotIn("alias", data.content)
self.assertNotIn("alt_aliases", data.content)

def test_remove_other_alias(self) -> None:
"""Removing an alias listed as in alt_aliases should remove it there too."""
@@ -378,9 +381,9 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)

data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual(
data["content"]["alt_aliases"], [self.test_alias, other_test_alias]
data.content["alt_aliases"], [self.test_alias, other_test_alias]
)

# Delete the second alias.
@@ -391,8 +394,8 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
)

data = self._get_canonical_alias()
self.assertEqual(data["content"]["alias"], self.test_alias)
self.assertEqual(data["content"]["alt_aliases"], [self.test_alias])
self.assertEqual(data.content["alias"], self.test_alias)
self.assertEqual(data.content["alt_aliases"], [self.test_alias])


class TestCreateAliasACL(unittest.HomeserverTestCase):


+ 44
- 32
tests/handlers/test_e2e_room_keys.py View File

@@ -17,7 +17,11 @@
import copy
from unittest import mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.errors import SynapseError
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest

@@ -39,14 +43,14 @@ room_keys = {


class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(replication_layer=mock.Mock())

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_room_keys_handler()
self.local_user = "@boris:" + hs.hostname

def test_get_missing_current_version_info(self):
def test_get_missing_current_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about the current version
if there is no version.
"""
@@ -56,7 +60,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_get_missing_version_info(self):
def test_get_missing_version_info(self) -> None:
"""Check that we get a 404 if we ask for info about a specific version
if it doesn't exist.
"""
@@ -67,9 +71,9 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_create_version(self):
def test_create_version(self) -> None:
"""Check that we can create and then retrieve versions."""
res = self.get_success(
version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -78,7 +82,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
self.assertEqual(res, "1")
self.assertEqual(version, "1")

# check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user))
@@ -110,7 +114,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)

# upload a new one...
res = self.get_success(
version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -119,7 +123,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
self.assertEqual(res, "2")
self.assertEqual(version, "2")

# check we can retrieve it as the current version
res = self.get_success(self.handler.get_version_info(self.local_user))
@@ -134,7 +138,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)

def test_update_version(self):
def test_update_version(self) -> None:
"""Check that we can update versions."""
version = self.get_success(
self.handler.create_version(
@@ -173,7 +177,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)

def test_update_missing_version(self):
def test_update_missing_version(self) -> None:
"""Check that we get a 404 on updating nonexistent versions"""
e = self.get_failure(
self.handler.update_version(
@@ -190,7 +194,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_update_omitted_version(self):
def test_update_omitted_version(self) -> None:
"""Check that the update succeeds if the version is missing from the body"""
version = self.get_success(
self.handler.create_version(
@@ -227,7 +231,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)

def test_update_bad_version(self):
def test_update_bad_version(self) -> None:
"""Check that we get a 400 if the version in the body doesn't match"""
version = self.get_success(
self.handler.create_version(
@@ -255,7 +259,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 400)

def test_delete_missing_version(self):
def test_delete_missing_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent versions"""
e = self.get_failure(
self.handler.delete_version(self.local_user, "1"), SynapseError
@@ -263,15 +267,15 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_delete_missing_current_version(self):
def test_delete_missing_current_version(self) -> None:
"""Check that we get a 404 on deleting nonexistent current version"""
e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
res = e.value.code
self.assertEqual(res, 404)

def test_delete_version(self):
def test_delete_version(self) -> None:
"""Check that we can create and then delete versions."""
res = self.get_success(
version = self.get_success(
self.handler.create_version(
self.local_user,
{
@@ -280,7 +284,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
},
)
)
self.assertEqual(res, "1")
self.assertEqual(version, "1")

# check we can delete it
self.get_success(self.handler.delete_version(self.local_user, "1"))
@@ -292,7 +296,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_get_missing_backup(self):
def test_get_missing_backup(self) -> None:
"""Check that we get a 404 on querying missing backup"""
e = self.get_failure(
self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
@@ -300,7 +304,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_get_missing_room_keys(self):
def test_get_missing_room_keys(self) -> None:
"""Check we get an empty response from an empty backup"""
version = self.get_success(
self.handler.create_version(
@@ -319,7 +323,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
# TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest

def test_upload_room_keys_no_versions(self):
def test_upload_room_keys_no_versions(self) -> None:
"""Check that we get a 404 on uploading keys when no versions are defined"""
e = self.get_failure(
self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
@@ -328,7 +332,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_upload_room_keys_bogus_version(self):
def test_upload_room_keys_bogus_version(self) -> None:
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
@@ -350,7 +354,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 404)

def test_upload_room_keys_wrong_version(self):
def test_upload_room_keys_wrong_version(self) -> None:
"""Check that we get a 403 on uploading keys for an old version"""
version = self.get_success(
self.handler.create_version(
@@ -380,7 +384,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
res = e.value.code
self.assertEqual(res, 403)

def test_upload_room_keys_insert(self):
def test_upload_room_keys_insert(self) -> None:
"""Check that we can insert and retrieve keys for a session"""
version = self.get_success(
self.handler.create_version(
@@ -416,7 +420,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
)
self.assertDictEqual(res, room_keys)

def test_upload_room_keys_merge(self):
def test_upload_room_keys_merge(self) -> None:
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
version = self.get_success(
@@ -449,9 +453,11 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)

res = self.get_success(self.handler.get_room_keys(self.local_user, version))
res_keys = self.get_success(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)

@@ -465,9 +471,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)

res = self.get_success(self.handler.get_room_keys(self.local_user, version))
res_keys = self.get_success(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"new",
)

# the etag should NOT be equal now, since the key changed
@@ -483,9 +492,12 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)

res = self.get_success(self.handler.get_room_keys(self.local_user, version))
res_keys = self.get_success(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
res_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"new",
)

# the etag should be the same since the session did not change
@@ -494,7 +506,7 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):

# TODO: check edge cases as well as the common variations here

def test_delete_room_keys(self):
def test_delete_room_keys(self) -> None:
"""Check that we can insert and delete keys for a session"""
version = self.get_success(
self.handler.create_version(


+ 1
- 1
tests/handlers/test_federation.py View File

@@ -439,7 +439,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")

def create_invite():
def create_invite() -> EventBase:
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))
return event_from_pdu_json(


+ 8
- 2
tests/handlers/test_federation_event.py View File

@@ -14,6 +14,8 @@
from typing import Optional
from unittest import mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.errors import AuthError, StoreError
from synapse.api.room_versions import RoomVersion
from synapse.event_auth import (
@@ -26,8 +28,10 @@ from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest
from tests.test_utils import event_injection, make_awaitable
@@ -40,7 +44,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room.register_servlets,
]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# mock out the federation transport client
self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
@@ -165,7 +169,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
else:

async def get_event(destination: str, event_id: str, timeout=None):
async def get_event(
destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, prev_event.event_id)
return {"pdus": [prev_event.get_pdu_json()]}


+ 16
- 10
tests/handlers/test_message.py View File

@@ -14,12 +14,16 @@
import logging
from typing import Tuple

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string

from tests import unittest
@@ -35,7 +39,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
@@ -94,7 +98,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)

def test_duplicated_txn_id(self):
def test_duplicated_txn_id(self) -> None:
"""Test that attempting to handle/persist an event with a transaction ID
that has already been persisted correctly returns the old event and does
*not* produce duplicate messages.
@@ -161,7 +165,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# rather than the new one.
self.assertEqual(ret_event1.event_id, ret_event4.event_id)

def test_duplicated_txn_id_one_call(self):
def test_duplicated_txn_id_one_call(self) -> None:
"""Test that we correctly handle duplicates that we try and persist at
the same time.
"""
@@ -185,7 +189,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[0].event_id, events[1].event_id)

def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(
self,
) -> None:
"""When we set allow_no_prev_events=True, should be able to create a
event without any prev_events (only auth_events).
"""
@@ -214,7 +220,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):

def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
self,
):
) -> None:
"""When we set allow_no_prev_events=False, shouldn't be able to create a
event without any prev_events even if it has auth_events. Expect an
exception to be raised.
@@ -245,7 +251,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):

def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
self,
):
) -> None:
"""When we set allow_no_prev_events=True, should be able to create a
event without any prev_events or auth_events. Expect an exception to be
raised.
@@ -277,12 +283,12 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)

def test_allow_server_acl(self):
def test_allow_server_acl(self) -> None:
"""Test that sending an ACL that blocks everyone but ourselves works."""

self.helper.send_state(
@@ -293,7 +299,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=200,
)

def test_deny_server_acl_block_outselves(self):
def test_deny_server_acl_block_outselves(self) -> None:
"""Test that sending an ACL that blocks ourselves does not work."""
self.helper.send_state(
self.room_id,
@@ -303,7 +309,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
expect_code=400,
)

def test_deny_redact_server_acl(self):
def test_deny_redact_server_acl(self) -> None:
"""Test that attempting to redact an ACL is blocked."""

body = self.helper.send_state(


+ 31
- 17
tests/handlers/test_oidc.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Tuple
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse

@@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string
@@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config

try:
import authlib # noqa: F401
from authlib.oidc.core import UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata

from synapse.handlers.oidc import Token, UserAttributeDict

HAS_OIDC = True
except ImportError:
@@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {

class TestMappingProvider:
@staticmethod
def parse_config(config):
return
def parse_config(config: JsonDict) -> None:
return None

def __init__(self, config):
def __init__(self, config: None):
pass

def get_remote_user_id(self, userinfo):
def get_remote_user_id(self, userinfo: "UserInfo") -> str:
return userinfo["sub"]

async def map_user_attributes(self, userinfo, token):
return {"localpart": userinfo["username"], "display_name": None}
async def map_user_attributes(
self, userinfo: "UserInfo", token: "Token"
) -> "UserAttributeDict":
# This is testing not providing the full map.
return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]

# Do not include get_extra_attributes to test backwards compatibility paths.


class TestMappingProviderExtra(TestMappingProvider):
async def get_extra_attributes(self, userinfo, token):
async def get_extra_attributes(
self, userinfo: "UserInfo", token: "Token"
) -> JsonDict:
return {"phone": userinfo["phone"]}


class TestMappingProviderFailures(TestMappingProvider):
async def map_user_attributes(self, userinfo, token, failures):
return {
# Superclass is testing the legacy interface for map_user_attributes.
async def map_user_attributes( # type: ignore[override]
self, userinfo: "UserInfo", token: "Token", failures: int
) -> "UserAttributeDict":
return { # type: ignore[typeddict-item]
"localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None,
}
@@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.hs_patcher.stop()
return super().tearDown()

def reset_mocks(self):
def reset_mocks(self) -> None:
"""Reset all the Mocks."""
self.fake_server.reset_mocks()
self.render_error.reset_mock()
self.complete_sso_login.reset_mock()

def metadata_edit(self, values):
def metadata_edit(self, values: dict) -> ContextManager[Mock]:
"""Modify the result that will be returned by the well-known query"""

metadata = self.fake_server.get_metadata()
@@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
return _build_callback_request(code, state, session), grant

def assertRenderedError(self, error, error_description=None):
def assertRenderedError(
self, error: str, error_description: Optional[str] = None
) -> Tuple[Any, ...]:
self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
@@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadatas are extensively validated."""
h = self.provider

def force_load_metadata():
async def force_load():
def force_load_metadata() -> Awaitable[None]:
async def force_load() -> "OpenIDProviderMetadata":
return await h.load_metadata(force=True)

return get_awaitable_result(force_load())
@@ -1198,7 +1212,7 @@ def _build_callback_request(
state: str,
session: str,
ip_address: str = "10.0.0.1",
):
) -> Mock:
"""Builds a fake SynapseRequest to mock the browser callback

Returns a Mock object which looks like the SynapseRequest we get from a browser


+ 74
- 70
tests/handlers/test_password_providers.py View File

@@ -15,12 +15,13 @@
"""Tests for the password_auth_provider interface"""

from http import HTTPStatus
from typing import Any, Type, Union
from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock

import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.handlers.account import AccountHandler
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -44,13 +45,13 @@ class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""

@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass

def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass

def check_password(self, *args):
def check_password(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)


@@ -58,16 +59,16 @@ class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type."""

@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass

def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass

def get_supported_login_types(self):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"test.login_type": ["test_field"]}

def check_auth(self, *args):
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)


@@ -75,15 +76,15 @@ class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""

@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass

def __init__(self, config, api: ModuleApi):
def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
)

def check_auth(self, *args):
def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)


@@ -92,16 +93,16 @@ class LegacyPasswordCustomAuthProvider:
as a custom type."""

@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass

def __init__(self, config, account_handler):
def __init__(self, config: None, account_handler: AccountHandler):
pass

def get_supported_login_types(self):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}

def check_auth(self, *args):
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)


@@ -110,10 +111,10 @@ class PasswordCustomAuthProvider:
as well as a password login"""

@staticmethod
def parse_config(self):
def parse_config(config: JsonDict) -> None:
pass

def __init__(self, config, api: ModuleApi):
def __init__(self, config: None, api: ModuleApi):
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
@@ -121,10 +122,10 @@ class PasswordCustomAuthProvider:
}
)

def check_auth(self, *args):
def check_auth(self, *args: Any) -> Mock:
return mock_password_provider.check_auth(*args)

def check_pass(self, *args):
def check_pass(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)


@@ -161,16 +162,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"

def setUp(self):
def setUp(self) -> None:
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
super().setUp()

@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
def test_password_only_auth_progiver_login_legacy(self) -> None:
self.password_only_auth_provider_login_test_body()

def password_only_auth_provider_login_test_body(self):
def password_only_auth_provider_login_test_body(self) -> None:
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -201,10 +202,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)

@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth_legacy(self):
def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
self.password_only_auth_provider_ui_auth_test_body()

def password_only_auth_provider_ui_auth_test_body(self):
def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider"""

# create the user, otherwise access doesn't work
@@ -238,10 +239,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")

@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_login_legacy(self):
def test_local_user_fallback_login_legacy(self) -> None:
self.local_user_fallback_login_test_body()

def local_user_fallback_login_test_body(self):
def local_user_fallback_login_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")

@@ -255,10 +256,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual("@localuser:test", channel.json_body["user_id"])

@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth_legacy(self):
def test_local_user_fallback_ui_auth_legacy(self) -> None:
self.local_user_fallback_ui_auth_test_body()

def local_user_fallback_ui_auth_test_body(self):
def local_user_fallback_ui_auth_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")

@@ -298,10 +299,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_login_legacy(self):
def test_no_local_user_fallback_login_legacy(self) -> None:
self.no_local_user_fallback_login_test_body()

def no_local_user_fallback_login_test_body(self):
def no_local_user_fallback_login_test_body(self) -> None:
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")

@@ -320,10 +321,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_ui_auth_legacy(self):
def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
self.no_local_user_fallback_ui_auth_test_body()

def no_local_user_fallback_ui_auth_test_body(self):
def no_local_user_fallback_ui_auth_test_body(self) -> None:
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")

@@ -361,10 +362,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_auth_disabled_legacy(self):
def test_password_auth_disabled_legacy(self) -> None:
self.password_auth_disabled_test_body()

def password_auth_disabled_test_body(self):
def password_auth_disabled_test_body(self) -> None:
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@@ -376,14 +377,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()

@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_login_legacy(self):
def test_custom_auth_provider_login_legacy(self) -> None:
self.custom_auth_provider_login_test_body()

@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
def test_custom_auth_provider_login(self) -> None:
self.custom_auth_provider_login_test_body()

def custom_auth_provider_login_test_body(self):
def custom_auth_provider_login_test_body(self) -> None:
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@@ -424,14 +425,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)

@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self):
def test_custom_auth_provider_ui_auth_legacy(self) -> None:
self.custom_auth_provider_ui_auth_test_body()

@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
def test_custom_auth_provider_ui_auth(self) -> None:
self.custom_auth_provider_ui_auth_test_body()

def custom_auth_provider_ui_auth_test_body(self):
def custom_auth_provider_ui_auth_test_body(self) -> None:
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@@ -486,14 +487,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)

@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_callback_legacy(self):
def test_custom_auth_provider_callback_legacy(self) -> None:
self.custom_auth_provider_callback_test_body()

@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
def test_custom_auth_provider_callback(self) -> None:
self.custom_auth_provider_callback_test_body()

def custom_auth_provider_callback_test_body(self):
def custom_auth_provider_callback_test_body(self) -> None:
callback = Mock(return_value=make_awaitable(None))

mock_password_provider.check_auth.return_value = make_awaitable(
@@ -521,16 +522,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_custom_auth_password_disabled_legacy(self):
def test_custom_auth_password_disabled_legacy(self) -> None:
self.custom_auth_password_disabled_test_body()

@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
def test_custom_auth_password_disabled(self) -> None:
self.custom_auth_password_disabled_test_body()

def custom_auth_password_disabled_test_body(self):
def custom_auth_password_disabled_test_body(self) -> None:
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")

@@ -548,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()

@override_config(
@@ -557,10 +558,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled(self):
def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
self.custom_auth_password_disabled_localdb_enabled_test_body()

def custom_auth_password_disabled_localdb_enabled_test_body(self):
def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
"""Check the localdb_enabled == enabled == False

Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -583,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login_legacy(self):
def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()

@override_config(
@@ -592,10 +593,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self):
def test_password_custom_auth_password_disabled_login(self) -> None:
self.password_custom_auth_password_disabled_login_test_body()

def password_custom_auth_password_disabled_login_test_body(self):
def password_custom_auth_password_disabled_login_test_body(self) -> None:
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()

@override_config(
@@ -624,10 +625,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
self.password_custom_auth_password_disabled_ui_auth_test_body()

def password_custom_auth_password_disabled_ui_auth_test_body(self):
def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
@@ -689,7 +690,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback_legacy(self):
def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()

@override_config(
@@ -698,10 +699,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback(self):
def test_custom_auth_no_local_user_fallback(self) -> None:
self.custom_auth_no_local_user_fallback_test_body()

def custom_auth_no_local_user_fallback_test_body(self):
def custom_auth_no_local_user_fallback_test_body(self) -> None:
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")

@@ -713,14 +714,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)

def test_on_logged_out(self):
def test_on_logged_out(self) -> None:
"""Tests that the on_logged_out callback is called when the user logs out."""
self.register_user("rin", "password")
tok = self.login("rin", "password")

self.called = False

async def on_logged_out(user_id, device_id, access_token):
async def on_logged_out(
user_id: str, device_id: Optional[str], access_token: str
) -> None:
self.called = True

on_logged_out = Mock(side_effect=on_logged_out)
@@ -738,7 +741,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
on_logged_out.assert_called_once()
self.assertTrue(self.called)

def test_username(self):
def test_username(self) -> None:
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
@@ -763,7 +766,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mxid = channel.json_body["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")

def test_username_uia(self):
def test_username_uia(self) -> None:
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
@@ -782,7 +785,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):

# Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}})
def test_3pid_allowed(self):
def test_3pid_allowed(self) -> None:
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
the 3PID. Also checks that the module is passed a boolean indicating whether the
@@ -791,7 +794,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)

def test_displayname(self):
def test_displayname(self) -> None:
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
"""
@@ -820,7 +823,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):

self.assertEqual(display_name, username + "-foo")

def test_displayname_uia(self):
def test_displayname_uia(self) -> None:
"""Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow.
"""
@@ -841,7 +844,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# Check that the callback has been called.
m.assert_called_once()

def _test_3pid_allowed(self, username: str, registration: bool):
def _test_3pid_allowed(self, username: str, registration: bool) -> None:
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.

@@ -907,7 +910,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
client is trying to register.
"""

async def callback(uia_results, params):
async def callback(uia_results: JsonDict, params: JsonDict) -> str:
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
@@ -950,12 +953,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)

def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
params.update(extra_params)
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel

def _start_delete_device_session(self, access_token, device_id) -> str:
def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)


+ 54
- 46
tests/handlers/test_presence.py View File

@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Optional, cast
from unittest.mock import Mock, call

from parameterized import parameterized
from signedjson.key import generate_signing_key

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -35,7 +37,9 @@ from synapse.handlers.presence import (
)
from synapse.rest import admin
from synapse.rest.client import room
from synapse.types import UserID, get_domain_from_id
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util import Clock

from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -44,10 +48,12 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]

def prepare(self, reactor, clock, homeserver):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.store = homeserver.get_datastores().main

def test_offline_to_online(self):
def test_offline_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -85,7 +91,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)

def test_online_to_online(self):
def test_online_to_online(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -128,7 +134,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)

def test_online_to_online_last_active_noop(self):
def test_online_to_online_last_active_noop(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -173,7 +179,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)

def test_online_to_online_last_active(self):
def test_online_to_online_last_active(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -210,7 +216,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)

def test_remote_ping_timer(self):
def test_remote_ping_timer(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -244,7 +250,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)

def test_online_to_offline(self):
def test_online_to_offline(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -266,7 +272,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):

self.assertEqual(wheel_timer.insert.call_count, 0)

def test_online_to_idle(self):
def test_online_to_idle(self) -> None:
wheel_timer = Mock()
user_id = "@foo:bar"
now = 5000000
@@ -300,7 +306,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
any_order=True,
)

def test_persisting_presence_updates(self):
def test_persisting_presence_updates(self) -> None:
"""Tests that the latest presence state for each user is persisted correctly"""
# Create some test users and presence states for them
presence_states = []
@@ -322,7 +328,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.update_presence(presence_states))

# Check that each update is present in the database
db_presence_states = self.get_success(
db_presence_states_raw = self.get_success(
self.store.get_all_presence_updates(
instance_name="master",
last_id=0,
@@ -332,7 +338,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
)

# Extract presence update user ID and state information into lists of tuples
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states_raw[0]]
presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]

# Compare what we put into the storage with what we got out.
@@ -343,7 +349,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
class PresenceTimeoutTestCase(unittest.TestCase):
"""Tests different timers and that the timer does not change `status_msg` of user."""

def test_idle_timer(self):
def test_idle_timer(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -363,7 +369,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
self.assertEqual(new_state.status_msg, status_msg)

def test_busy_no_idle(self):
def test_busy_no_idle(self) -> None:
"""
Tests that a user setting their presence to busy but idling doesn't turn their
presence state into unavailable.
@@ -387,7 +393,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.BUSY)
self.assertEqual(new_state.status_msg, status_msg)

def test_sync_timeout(self):
def test_sync_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -407,7 +413,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)

def test_sync_online(self):
def test_sync_online(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -429,7 +435,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.ONLINE)
self.assertEqual(new_state.status_msg, status_msg)

def test_federation_ping(self):
def test_federation_ping(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -448,7 +454,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEqual(state, new_state)

def test_no_timeout(self):
def test_no_timeout(self) -> None:
user_id = "@foo:bar"
now = 5000000

@@ -464,7 +470,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):

self.assertIsNone(new_state)

def test_federation_timeout(self):
def test_federation_timeout(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -487,7 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(new_state.state, PresenceState.OFFLINE)
self.assertEqual(new_state.status_msg, status_msg)

def test_last_active(self):
def test_last_active(self) -> None:
user_id = "@foo:bar"
status_msg = "I'm here!"
now = 5000000
@@ -508,15 +514,15 @@ class PresenceTimeoutTestCase(unittest.TestCase):


class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()

def test_external_process_timeout(self):
def test_external_process_timeout(self) -> None:
"""Test that if an external process doesn't update the records for a while
we time out their syncing users presence.
"""
process_id = 1
process_id = "1"
user_id = "@test:server"

# Notify handler that a user is now syncing.
@@ -544,7 +550,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertEqual(state.state, PresenceState.OFFLINE)

def test_user_goes_offline_by_timeout_status_msg_remain(self):
def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None:
"""Test that if a user doesn't update the records for a while
users presence goes `OFFLINE` because of timeout and `status_msg` remains.
"""
@@ -576,7 +582,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, status_msg)

def test_user_goes_offline_manually_with_no_status_msg(self):
def test_user_goes_offline_manually_with_no_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and no status is set, that `status_msg` is `None`.
"""
@@ -601,7 +607,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.OFFLINE)
self.assertEqual(state.status_msg, None)

def test_user_goes_offline_manually_with_status_msg(self):
def test_user_goes_offline_manually_with_status_msg(self) -> None:
"""Test that if a user change presence manually to `OFFLINE`
and a status is set, that `status_msg` appears.
"""
@@ -618,7 +624,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
user_id, PresenceState.OFFLINE, "And now here."
)

def test_user_reset_online_with_no_status(self):
def test_user_reset_online_with_no_status(self) -> None:
"""Test that if a user set again the presence manually
and no status is set, that `status_msg` is `None`.
"""
@@ -644,7 +650,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(state.state, PresenceState.ONLINE)
self.assertEqual(state.status_msg, None)

def test_set_presence_with_status_msg_none(self):
def test_set_presence_with_status_msg_none(self) -> None:
"""Test that if a user set again the presence manually
and status is `None`, that `status_msg` is `None`.
"""
@@ -659,7 +665,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# Mark user as online and `status_msg = None`
self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)

def test_set_presence_from_syncing_not_set(self):
def test_set_presence_from_syncing_not_set(self) -> None:
"""Test that presence is not set by syncing if affect_presence is false"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -680,7 +686,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# and status message should still be the same
self.assertEqual(state.status_msg, status_msg)

def test_set_presence_from_syncing_is_set(self):
def test_set_presence_from_syncing_is_set(self) -> None:
"""Test that presence is set by syncing if affect_presence is true"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -699,7 +705,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should now be online
self.assertEqual(state.state, PresenceState.ONLINE)

def test_set_presence_from_syncing_keeps_status(self):
def test_set_presence_from_syncing_keeps_status(self) -> None:
"""Test that presence set by syncing retains status message"""
user_id = "@test:server"
status_msg = "I'm here!"
@@ -726,7 +732,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
},
}
)
def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
def test_set_presence_from_syncing_keeps_busy(
self, test_with_workers: bool
) -> None:
"""Test that presence set by syncing doesn't affect busy status

Args:
@@ -767,7 +775,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):

def _set_presencestate_with_status_msg(
self, user_id: str, state: str, status_msg: Optional[str]
):
) -> None:
"""Set a PresenceState and status_msg and check the result.

Args:
@@ -790,14 +798,14 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):


class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.instance_name = hs.get_instance_name()

self.queue = self.presence_handler.get_federation_queue()

def test_send_and_get(self):
def test_send_and_get(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -834,7 +842,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
self.assertFalse(limited)
self.assertCountEqual(rows, [])

def test_send_and_get_split(self):
def test_send_and_get_split(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -877,7 +885,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):

self.assertCountEqual(rows, expected_rows)

def test_clear_queue_all(self):
def test_clear_queue_all(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -921,7 +929,7 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):

self.assertCountEqual(rows, expected_rows)

def test_partially_clear_queue(self):
def test_partially_clear_queue(self) -> None:
state1 = UserPresenceState.default("@user1:test")
state2 = UserPresenceState.default("@user2:test")
state3 = UserPresenceState.default("@user3:test")
@@ -982,7 +990,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):

servlets = [room.register_servlets]

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(
"server",
federation_http_client=None,
@@ -990,14 +998,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
)
return hs

def default_config(self):
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable federation sending on the main process.
config["federation_sender_instances"] = None
return config

def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.federation_sender = cast(Mock, hs.get_federation_sender())
self.event_builder_factory = hs.get_event_builder_factory()
self.federation_event_handler = hs.get_federation_event_handler()
self.presence_handler = hs.get_presence_handler()
@@ -1013,7 +1021,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# random key to use.
self.random_signing_key = generate_signing_key("ver")

def test_remote_joins(self):
def test_remote_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1061,7 +1069,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server3"}, states=[expected_state]
)

def test_remote_gets_presence_when_local_user_joins(self):
def test_remote_gets_presence_when_local_user_joins(self) -> None:
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
@@ -1110,7 +1118,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
destinations={"server2", "server3"}, states=[expected_state]
)

def _add_new_user(self, room_id, user_id):
def _add_new_user(self, room_id: str, user_id: str) -> None:
"""Add new user to the room by creating an event and poking the federation API."""

hostname = get_domain_from_id(user_id)


+ 2
- 2
tests/handlers/test_profile.py View File

@@ -332,7 +332,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
@unittest.override_config(
{"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
)
def test_avatar_constraint_on_local_server_with_port(self):
def test_avatar_constraint_on_local_server_with_port(self) -> None:
"""Test that avatar metadata is correctly fetched when the media is on a local
server and the server has an explicit port.

@@ -376,7 +376,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
)

def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.

Args:


+ 5
- 1
tests/handlers/test_receipts.py View File

@@ -15,14 +15,18 @@
from copy import deepcopy
from typing import List

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests import unittest


class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_source = hs.get_event_sources().sources.receipt

def test_filters_out_private_receipt(self) -> None:


+ 97
- 72
tests/handlers/test_register.py View File

@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Collection, List, Optional, Tuple
from unittest.mock import Mock

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import (
@@ -22,8 +25,18 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
from synapse.module_api import ModuleApi
from synapse.server import HomeServer
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.types import (
JsonDict,
Requester,
RoomAlias,
RoomID,
UserID,
create_requester,
)
from synapse.util import Clock

from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -33,94 +46,98 @@ from .. import unittest


class TestSpamChecker:
def __init__(self, config, api):
def __init__(self, config: None, api: ModuleApi):
api.register_spam_checker_callbacks(
check_registration_for_spam=self.check_registration_for_spam,
)

@staticmethod
def parse_config(config):
return config
def parse_config(config: JsonDict) -> None:
return None

async def check_registration_for_spam(
self,
email_threepid,
username,
request_info,
auth_provider_id,
):
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
pass


class DenyAll(TestSpamChecker):
async def check_registration_for_spam(
self,
email_threepid,
username,
request_info,
auth_provider_id,
):
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY


class BanAll(TestSpamChecker):
async def check_registration_for_spam(
self,
email_threepid,
username,
request_info,
auth_provider_id,
):
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
return RegistrationBehaviour.SHADOW_BAN


class BanBadIdPUser(TestSpamChecker):
async def check_registration_for_spam(
self, email_threepid, username, request_info, auth_provider_id=None
):
self,
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str] = None,
) -> RegistrationBehaviour:
# Reject any user coming from CAS and whose username contains profanity
if auth_provider_id == "cas" and "flimflob" in username:
if auth_provider_id == "cas" and username and "flimflob" in username:
return RegistrationBehaviour.DENY
return RegistrationBehaviour.ALLOW


class TestLegacyRegistrationSpamChecker:
def __init__(self, config, api):
def __init__(self, config: None, api: ModuleApi):
pass

async def check_registration_for_spam(
self,
email_threepid,
username,
request_info,
):
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
pass


class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam(
self,
email_threepid,
username,
request_info,
):
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
return RegistrationBehaviour.ALLOW


class LegacyDenyAll(TestLegacyRegistrationSpamChecker):
async def check_registration_for_spam(
self,
email_threepid,
username,
request_info,
):
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
return RegistrationBehaviour.DENY


class RegistrationTestCase(unittest.HomeserverTestCase):
"""Tests the RegistrationHandler."""

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs_config = self.default_config()

# some of the tests rely on us having a user consent version
@@ -145,7 +162,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):

return hs

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main
self.lots_of_users = 100
@@ -153,7 +170,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):

self.requester = create_requester("@requester:test")

def test_user_is_created_and_logged_in_if_doesnt_exist(self):
def test_user_is_created_and_logged_in_if_doesnt_exist(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = frank.to_string()
requester = create_requester(user_id)
@@ -164,7 +181,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(result_token, str)
self.assertGreater(len(result_token), 20)

def test_if_user_exists(self):
def test_if_user_exists(self) -> None:
store = self.hs.get_datastores().main
frank = UserID.from_string("@frank:test")
self.get_success(
@@ -180,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertTrue(result_token is not None)

@override_config({"limit_usage_by_mau": False})
def test_mau_limits_when_disabled(self):
def test_mau_limits_when_disabled(self) -> None:
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))

@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
@@ -193,7 +210,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_success(self.get_or_create_user(self.requester, "c", "User"))

@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
def test_get_or_create_user_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -211,7 +228,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)

@override_config({"limit_usage_by_mau": True})
def test_register_mau_blocked(self):
def test_register_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
@@ -229,7 +246,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_rooms_for_guests": False}
)
def test_auto_join_rooms_for_guests(self):
def test_auto_join_rooms_for_guests(self) -> None:
user_id = self.get_success(
self.handler.register_user(localpart="jeff", make_guest=True),
)
@@ -237,7 +254,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0)

@override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms(self):
def test_auto_create_auto_join_rooms(self) -> None:
room_alias_str = "#room:test"
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -249,7 +266,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1)

@override_config({"auto_join_rooms": []})
def test_auto_create_auto_join_rooms_with_no_rooms(self):
def test_auto_create_auto_join_rooms_with_no_rooms(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
@@ -257,7 +274,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 0)

@override_config({"auto_join_rooms": ["#room:another"]})
def test_auto_create_auto_join_where_room_is_another_domain(self):
def test_auto_create_auto_join_where_room_is_another_domain(self) -> None:
frank = UserID.from_string("@frank:test")
user_id = self.get_success(self.handler.register_user(frank.localpart))
self.assertEqual(user_id, frank.to_string())
@@ -267,13 +284,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "autocreate_auto_join_rooms": False}
)
def test_auto_create_auto_join_where_auto_create_is_false(self):
def test_auto_create_auto_join_where_auto_create_is_false(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)

@override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self):
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
@@ -284,7 +301,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.get_failure(directory_handler.get_association(room_alias), SynapseError)

@override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"

self.store.count_real_users = Mock(return_value=make_awaitable(1))
@@ -299,7 +316,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(rooms), 1)

@override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(self):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
@@ -312,7 +331,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"autocreate_auto_join_rooms_federated": False,
}
)
def test_auto_create_auto_join_rooms_federated(self):
def test_auto_create_auto_join_rooms_federated(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -339,7 +358,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config(
{"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
)
def test_auto_join_mxid_localpart(self):
def test_auto_join_mxid_localpart(self) -> None:
"""
Ensure the user still needs up in the room created by a different user.
"""
@@ -376,7 +395,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
def test_auto_create_auto_join_room_preset(self):
def test_auto_create_auto_join_room_preset(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -416,7 +435,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
def test_auto_create_auto_join_room_preset_guest(self):
def test_auto_create_auto_join_room_preset_guest(self) -> None:
"""
Auto-created rooms that are private require an invite to go to the user
(instead of directly joining it).
@@ -454,7 +473,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_mxid_localpart": "support",
}
)
def test_auto_create_auto_join_room_preset_invalid_permissions(self):
def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None:
"""
Auto-created rooms that are private require an invite, check that
registration doesn't completely break if the inviter doesn't have proper
@@ -525,7 +544,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"auto_join_rooms": ["#room:test"],
},
)
def test_auto_create_auto_join_where_no_consent(self):
def test_auto_create_auto_join_where_no_consent(self) -> None:
"""Test to ensure that the first user is not auto-joined to a room if
they have not given general consent.
"""
@@ -550,19 +569,19 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 1)

def test_register_support_user(self):
def test_register_support_user(self) -> None:
user_id = self.get_success(
self.handler.register_user(localpart="user", user_type=UserTypes.SUPPORT)
)
d = self.store.is_support_user(user_id)
self.assertTrue(self.get_success(d))

def test_register_not_support_user(self):
def test_register_not_support_user(self) -> None:
user_id = self.get_success(self.handler.register_user(localpart="user"))
d = self.store.is_support_user(user_id)
self.assertFalse(self.get_success(d))

def test_invalid_user_id_length(self):
def test_invalid_user_id_length(self) -> None:
invalid_user_id = "x" * 256
self.get_failure(
self.handler.register_user(localpart=invalid_user_id), SynapseError
@@ -577,7 +596,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
def test_spam_checker_deny(self):
def test_spam_checker_deny(self) -> None:
"""A spam checker can deny registration, which results in an error."""
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)

@@ -590,7 +609,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
def test_spam_checker_legacy_allow(self):
def test_spam_checker_legacy_allow(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called.

@@ -610,7 +629,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
def test_spam_checker_legacy_deny(self):
def test_spam_checker_legacy_deny(self) -> None:
"""Tests that a legacy spam checker implementing the legacy 3-arg version of the
check_registration_for_spam callback is correctly called.

@@ -630,7 +649,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
def test_spam_checker_shadow_ban(self):
def test_spam_checker_shadow_ban(self) -> None:
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
user_id = self.get_success(self.handler.register_user(localpart="user"))

@@ -660,7 +679,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
]
}
)
def test_spam_checker_receives_sso_type(self):
def test_spam_checker_receives_sso_type(self) -> None:
"""Test rejecting registration based on SSO type"""
f = self.get_failure(
self.handler.register_user(localpart="bobflimflob", auth_provider_id="cas"),
@@ -678,8 +697,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)

async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
self,
requester: Requester,
localpart: str,
displayname: Optional[str],
password_hash: Optional[str] = None,
) -> Tuple[str, str]:
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.

@@ -734,13 +757,15 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
"""Tests auto-join on remote rooms."""

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.room_id = "!roomid:remotetest"

async def update_membership(*args, **kwargs):
async def update_membership(*args: Any, **kwargs: Any) -> None:
pass

async def lookup_room_alias(*args, **kwargs):
async def lookup_room_alias(
*args: Any, **kwargs: Any
) -> Tuple[RoomID, List[str]]:
return RoomID.from_string(self.room_id), ["remotetest"]

self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"])
@@ -750,12 +775,12 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler)
return hs

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_registration_handler()
self.store = self.hs.get_datastores().main

@override_config({"auto_join_rooms": ["#room:remotetest"]})
def test_auto_create_auto_join_remote_room(self):
def test_auto_create_auto_join_remote_room(self) -> None:
"""Tests that we don't attempt to create remote rooms, and that we don't attempt
to invite ourselves to rooms we're not in."""



+ 3
- 3
tests/handlers/test_room.py View File

@@ -14,7 +14,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
]

@override_config({"encryption_enabled_by_default_for_room_type": "all"})
def test_encrypted_by_default_config_option_all(self):
def test_encrypted_by_default_config_option_all(self) -> None:
"""Tests that invite-only and non-invite-only rooms have encryption enabled by
default when the config option encryption_enabled_by_default_for_room_type is "all".
"""
@@ -45,7 +45,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})

@override_config({"encryption_enabled_by_default_for_room_type": "invite"})
def test_encrypted_by_default_config_option_invite(self):
def test_encrypted_by_default_config_option_invite(self) -> None:
"""Tests that only new, invite-only rooms have encryption enabled by default when
the config option encryption_enabled_by_default_for_room_type is "invite".
"""
@@ -76,7 +76,7 @@ class EncryptedByDefaultTestCase(unittest.HomeserverTestCase):
)

@override_config({"encryption_enabled_by_default_for_room_type": "off"})
def test_encrypted_by_default_config_option_off(self):
def test_encrypted_by_default_config_option_off(self) -> None:
"""Tests that neither new invite-only nor non-invite-only rooms have encryption
enabled by default when the config option
encryption_enabled_by_default_for_room_type is "off".


+ 44
- 32
tests/handlers/test_room_summary.py View File

@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from unittest import mock

from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import (
EventContentFields,
@@ -34,11 +35,14 @@ from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock

from tests import unittest


def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
def _create_event(
room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0
) -> mock.Mock:
result = mock.Mock(name=room_id)
result.room_id = room_id
result.content = {}
@@ -48,40 +52,40 @@ def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: i
return result


def _order(*events):
def _order(*events: mock.Mock) -> List[mock.Mock]:
return sorted(events, key=_child_events_comparison_key)


class TestSpaceSummarySort(unittest.TestCase):
def test_no_order_last(self):
def test_no_order_last(self) -> None:
"""An event with no ordering is placed behind those with an ordering."""
ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test", "xyz")

self.assertEqual([ev2, ev1], _order(ev1, ev2))

def test_order(self):
def test_order(self) -> None:
"""The ordering should be used."""
ev1 = _create_event("!abc:test", "xyz")
ev2 = _create_event("!xyz:test", "abc")

self.assertEqual([ev2, ev1], _order(ev1, ev2))

def test_order_origin_server_ts(self):
def test_order_origin_server_ts(self) -> None:
"""Origin server is a tie-breaker for ordering."""
ev1 = _create_event("!abc:test", origin_server_ts=10)
ev2 = _create_event("!xyz:test", origin_server_ts=30)

self.assertEqual([ev1, ev2], _order(ev1, ev2))

def test_order_room_id(self):
def test_order_room_id(self) -> None:
"""Room ID is a final tie-breaker for ordering."""
ev1 = _create_event("!abc:test")
ev2 = _create_event("!xyz:test")

self.assertEqual([ev1, ev2], _order(ev1, ev2))

def test_invalid_ordering_type(self):
def test_invalid_ordering_type(self) -> None:
"""Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", 1)
ev2 = _create_event("!xyz:test", "xyz")
@@ -97,7 +101,7 @@ class TestSpaceSummarySort(unittest.TestCase):
ev1 = _create_event("!abc:test", True)
self.assertEqual([ev2, ev1], _order(ev1, ev2))

def test_invalid_ordering_value(self):
def test_invalid_ordering_value(self) -> None:
"""Invalid orderings are considered the same as missing."""
ev1 = _create_event("!abc:test", "foo\n")
ev2 = _create_event("!xyz:test", "xyz")
@@ -115,7 +119,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.handler = self.hs.get_room_summary_handler()

@@ -223,7 +227,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
)

def test_simple_space(self):
def test_simple_space(self) -> None:
"""Test a simple space with a single room."""
# The result should have the space and the room in it, along with a link
# from space -> room.
@@ -234,7 +238,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)

def test_large_space(self):
def test_large_space(self) -> None:
"""Test a space with a large number of rooms."""
rooms = [self.room]
# Make at least 51 rooms that are part of the space.
@@ -260,7 +264,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result["rooms"] += result2["rooms"]
self._assert_hierarchy(result, expected)

def test_visibility(self):
def test_visibility(self) -> None:
"""A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass")
@@ -380,7 +384,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result2, [(self.space, [self.room])])

def _create_room_with_join_rule(
self, join_rule: str, room_version: Optional[str] = None, **extra_content
self, join_rule: str, room_version: Optional[str] = None, **extra_content: Any
) -> str:
"""Create a room with the given join rule and add it to the space."""
room_id = self.helper.create_room_as(
@@ -403,7 +407,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._add_child(self.space, room_id, self.token)
return room_id

def test_filtering(self):
def test_filtering(self) -> None:
"""
Rooms should be properly filtered to only include rooms the user has access to.
"""
@@ -476,7 +480,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)

def test_complex_space(self):
def test_complex_space(self) -> None:
"""
Create a "complex" space to see how it handles things like loops and subspaces.
"""
@@ -516,7 +520,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)

def test_pagination(self):
def test_pagination(self) -> None:
"""Test simple pagination works."""
room_ids = []
for i in range(1, 10):
@@ -553,7 +557,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._assert_hierarchy(result, expected)
self.assertNotIn("next_batch", result)

def test_invalid_pagination_token(self):
def test_invalid_pagination_token(self) -> None:
"""An invalid pagination token, or changing other parameters, shoudl be rejected."""
room_ids = []
for i in range(1, 10):
@@ -604,7 +608,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
SynapseError,
)

def test_max_depth(self):
def test_max_depth(self) -> None:
"""Create a deep tree to test the max depth against."""
spaces = [self.space]
rooms = [self.room]
@@ -659,7 +663,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
]
self._assert_hierarchy(result, expected)

def test_unknown_room_version(self):
def test_unknown_room_version(self) -> None:
"""
If a room with an unknown room version is encountered it should not cause
the entire summary to skip.
@@ -685,7 +689,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)

def test_fed_complex(self):
def test_fed_complex(self) -> None:
"""
Return data over federation and ensure that it is handled properly.
"""
@@ -722,7 +726,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
"world_readable": True,
}

async def summarize_remote_room_hierarchy(_self, room, suggested_only):
async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {subroom: child_room}, set()

# Add a room to the space which is on another server.
@@ -744,7 +750,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)

def test_fed_filtering(self):
def test_fed_filtering(self) -> None:
"""
Rooms returned over federation should be properly filtered to only include
rooms the user has access to.
@@ -853,7 +859,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
],
)

async def summarize_remote_room_hierarchy(_self, room, suggested_only):
async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return subspace_room_entry, dict(children_rooms), set()

# Add a room to the space which is on another server.
@@ -892,7 +900,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)

def test_fed_invited(self):
def test_fed_invited(self) -> None:
"""
A room which the user was invited to should be included in the response.

@@ -915,7 +923,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
},
)

async def summarize_remote_room_hierarchy(_self, room, suggested_only):
async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return fed_room_entry, {}, set()

# Add a room to the space which is on another server.
@@ -936,7 +946,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)

def test_fed_caching(self):
def test_fed_caching(self) -> None:
"""
Federation `/hierarchy` responses should be cached.
"""
@@ -1023,7 +1033,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.hs = hs
self.handler = self.hs.get_room_summary_handler()

@@ -1040,12 +1050,12 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
tok=self.token,
)

def test_own_room(self):
def test_own_room(self) -> None:
"""Test a simple room created by the requester."""
result = self.get_success(self.handler.get_room_summary(self.user, self.room))
self.assertEqual(result.get("room_id"), self.room)

def test_visibility(self):
def test_visibility(self) -> None:
"""A user not in a private room cannot get its summary."""
user2 = self.register_user("user2", "pass")
token2 = self.login("user2", "pass")
@@ -1093,7 +1103,7 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
result = self.get_success(self.handler.get_room_summary(user2, self.room))
self.assertEqual(result.get("room_id"), self.room)

def test_fed(self):
def test_fed(self) -> None:
"""
Return data over federation and ensure that it is handled properly.
"""
@@ -1105,7 +1115,9 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase):
{"room_id": fed_room, "world_readable": True},
)

async def summarize_remote_room_hierarchy(_self, room, suggested_only):
async def summarize_remote_room_hierarchy(
_self: Any, room: Any, suggested_only: bool
) -> Tuple[Optional[_RoomEntry], Dict[str, JsonDict], Set[str]]:
return requested_room_entry, {}, set()

with mock.patch(


+ 22
- 11
tests/handlers/test_saml.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Set, Tuple
from unittest.mock import Mock

import attr
@@ -20,7 +20,9 @@ import attr
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.errors import RedirectException
from synapse.module_api import ModuleApi
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock

from tests.test_utils import simple_async_mock
@@ -29,6 +31,7 @@ from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
try:
import saml2.config
import saml2.response
from saml2.sigver import SigverError

has_saml2 = True
@@ -56,31 +59,39 @@ class FakeAuthnResponse:


class TestMappingProvider:
def __init__(self, config, module):
def __init__(self, config: None, module: ModuleApi):
pass

@staticmethod
def parse_config(config):
return
def parse_config(config: JsonDict) -> None:
return None

@staticmethod
def get_saml_attributes(config):
def get_saml_attributes(config: None) -> Tuple[Set[str], Set[str]]:
return {"uid"}, {"displayName"}

def get_remote_user_id(self, saml_response, client_redirect_url):
def get_remote_user_id(
self, saml_response: "saml2.response.AuthnResponse", client_redirect_url: str
) -> str:
return saml_response.ava["uid"]

def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url
):
self,
saml_response: "saml2.response.AuthnResponse",
failures: int,
client_redirect_url: str,
) -> dict:
localpart = saml_response.ava["username"] + (str(failures) if failures else "")
return {"mxid_localpart": localpart, "displayname": None}


class TestRedirectMappingProvider(TestMappingProvider):
def saml_response_to_user_attributes(
self, saml_response, failures, client_redirect_url
):
self,
saml_response: "saml2.response.AuthnResponse",
failures: int,
client_redirect_url: str,
) -> dict:
raise RedirectException(b"https://custom-saml-redirect/")


@@ -347,7 +358,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
)


def _mock_request():
def _mock_request() -> Mock:
"""Returns a mock which will stand in as a SynapseRequest"""
mock = Mock(
spec=[


+ 18
- 11
tests/handlers/test_send_email.py View File

@@ -13,7 +13,7 @@
# limitations under the License.


from typing import List, Tuple
from typing import Callable, List, Tuple

from zope.interface import implementer

@@ -28,20 +28,27 @@ from tests.unittest import HomeserverTestCase, override_config

@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
def __init__(self):
def __init__(self) -> None:
# (recipient, message) tuples
self.messages: List[Tuple[smtp.Address, bytes]] = []

def receivedHeader(self, helo, origin, recipients):
def receivedHeader(
self,
helo: Tuple[bytes, bytes],
origin: smtp.Address,
recipients: List[smtp.User],
) -> None:
return None

def validateFrom(self, helo, origin):
def validateFrom(
self, helo: Tuple[bytes, bytes], origin: smtp.Address
) -> smtp.Address:
return origin

def record_message(self, recipient: smtp.Address, message: bytes):
def record_message(self, recipient: smtp.Address, message: bytes) -> None:
self.messages.append((recipient, message))

def validateTo(self, user: smtp.User):
def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
return lambda: _DummyMessage(self, user)


@@ -56,20 +63,20 @@ class _DummyMessage:
self._user = user
self._buffer: List[bytes] = []

def lineReceived(self, line):
def lineReceived(self, line: bytes) -> None:
self._buffer.append(line)

def eomReceived(self):
def eomReceived(self) -> "defer.Deferred[bytes]":
message = b"\n".join(self._buffer) + b"\n"
self._delivery.record_message(self._user.dest, message)
return defer.succeed(b"saved")

def connectionLost(self):
def connectionLost(self) -> None:
pass


class SendEmailHandlerTestCase(HomeserverTestCase):
def test_send_email(self):
def test_send_email(self) -> None:
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
@@ -119,7 +126,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
},
}
)
def test_send_email_force_tls(self):
def test_send_email_force_tls(self) -> None:
"""Happy-path test that we can send email to an Implicit TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(


+ 53
- 21
tests/handlers/test_stats.py View File

@@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

from twisted.test.proto_helpers import MemoryReactor

from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main import stats
from synapse.util import Clock

from tests import unittest

@@ -32,11 +38,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
login.register_servlets,
]

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.handler = self.hs.get_stats_handler()

def _add_background_updates(self):
def _add_background_updates(self) -> None:
"""
Add the background updates we need to run.
"""
@@ -63,12 +69,14 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)

async def get_all_room_state(self):
async def get_all_room_state(self) -> List[Dict[str, Any]]:
return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)

def _get_current_stats(self, stats_type, stat_id):
def _get_current_stats(
self, stats_type: str, stat_id: str
) -> Optional[Dict[str, Any]]:
table, id_col = stats.TYPE_TO_TABLE[stats_type]

cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type])
@@ -82,13 +90,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
)

def _perform_background_initial_update(self):
def _perform_background_initial_update(self) -> None:
# Do the initial population of the stats via the background update
self._add_background_updates()

self.wait_for_background_updates()

def test_initial_room(self):
def test_initial_room(self) -> None:
"""
The background updates will build the table from scratch.
"""
@@ -125,7 +133,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(len(r), 1)
self.assertEqual(r[0]["topic"], "foo")

def test_create_user(self):
def test_create_user(self) -> None:
"""
When we create a user, it should have statistics already ready.
"""
@@ -134,12 +142,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):

u1stats = self._get_current_stats("user", u1)

self.assertIsNotNone(u1stats)
assert u1stats is not None

# not in any rooms by default
self.assertEqual(u1stats["joined_rooms"], 0)

def test_create_room(self):
def test_create_room(self) -> None:
"""
When we create a room, it should have statistics already ready.
"""
@@ -153,8 +161,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r2 = self.helper.create_room_as(u1, tok=u1token, is_public=False)
r2stats = self._get_current_stats("room", r2)

self.assertIsNotNone(r1stats)
self.assertIsNotNone(r2stats)
assert r1stats is not None
assert r2stats is not None

self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@@ -171,7 +179,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r2stats["invited_members"], 0)
self.assertEqual(r2stats["banned_members"], 0)

def test_updating_profile_information_does_not_increase_joined_members_count(self):
def test_updating_profile_information_does_not_increase_joined_members_count(
self,
) -> None:
"""
Check that the joined_members count does not increase when a user changes their
profile information (which is done by sending another join membership event into
@@ -186,6 +196,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):

# Get the current room stats
r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

# Send a profile update into the room
new_profile = {"displayname": "bob"}
@@ -195,6 +206,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):

# Get the new room stats
r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

# Ensure that the user count did not changed
self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
@@ -202,7 +214,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
)

def test_send_state_event_nonoverwriting(self):
def test_send_state_event_nonoverwriting(self) -> None:
"""
When we send a non-overwriting state event, it increments current_state_events
"""
@@ -218,19 +230,21 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)

r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

self.helper.send_state(
r1, "cat.hissing", {"value": False}, tok=u1token, state_key="moggy"
)

r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
1,
)

def test_join_first_time(self):
def test_join_first_time(self) -> None:
"""
When a user joins a room for the first time, current_state_events and
joined_members should increase by exactly 1.
@@ -246,10 +260,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2token = self.login("u2", "pass")

r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

self.helper.join(r1, u2, tok=u2token)

r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -259,7 +275,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], 1
)

def test_join_after_leave(self):
def test_join_after_leave(self) -> None:
"""
When a user joins a room after being previously left,
joined_members should increase by exactly 1.
@@ -280,10 +296,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.leave(r1, u2, tok=u2token)

r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

self.helper.join(r1, u2, tok=u2token)

r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -296,7 +314,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["left_members"] - r1stats_ante["left_members"], -1
)

def test_invited(self):
def test_invited(self) -> None:
"""
When a user invites another user, current_state_events and
invited_members should increase by exactly 1.
@@ -311,10 +329,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
u2 = self.register_user("u2", "pass")

r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

self.helper.invite(r1, u1, u2, tok=u1token)

r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -324,7 +344,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], +1
)

def test_join_after_invite(self):
def test_join_after_invite(self) -> None:
"""
When a user joins a room after being invited and
joined_members should increase by exactly 1.
@@ -344,10 +364,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.invite(r1, u1, u2, tok=u1token)

r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

self.helper.join(r1, u2, tok=u2token)

r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -360,7 +382,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["invited_members"] - r1stats_ante["invited_members"], -1
)

def test_left(self):
def test_left(self) -> None:
"""
When a user leaves a room after joining and
left_members should increase by exactly 1.
@@ -380,10 +402,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token)

r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

self.helper.leave(r1, u2, tok=u2token)

r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -396,7 +420,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
)

def test_banned(self):
def test_banned(self) -> None:
"""
When a user is banned from a room after joining and
left_members should increase by exactly 1.
@@ -416,10 +440,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.helper.join(r1, u2, tok=u2token)

r1stats_ante = self._get_current_stats("room", r1)
assert r1stats_ante is not None

self.helper.change_membership(r1, u1, u2, "ban", tok=u1token)

r1stats_post = self._get_current_stats("room", r1)
assert r1stats_post is not None

self.assertEqual(
r1stats_post["current_state_events"] - r1stats_ante["current_state_events"],
@@ -432,7 +458,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats_post["joined_members"] - r1stats_ante["joined_members"], -1
)

def test_initial_background_update(self):
def test_initial_background_update(self) -> None:
"""
Test that statistics can be generated by the initial background update
handler.
@@ -462,6 +488,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
r1stats = self._get_current_stats("room", r1)
u1stats = self._get_current_stats("user", u1)

assert r1stats is not None
assert u1stats is not None

self.assertEqual(r1stats["joined_members"], 1)
self.assertEqual(
r1stats["current_state_events"], EXPT_NUM_STATE_EVTS_IN_FRESH_PUBLIC_ROOM
@@ -469,7 +498,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):

self.assertEqual(u1stats["joined_rooms"], 1)

def test_incomplete_stats(self):
def test_incomplete_stats(self) -> None:
"""
This tests that we track incomplete statistics.

@@ -533,8 +562,11 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.wait_for_background_updates()

r1stats_complete = self._get_current_stats("room", r1)
assert r1stats_complete is not None
u1stats_complete = self._get_current_stats("user", u1)
assert u1stats_complete is not None
u2stats_complete = self._get_current_stats("user", u2)
assert u2stats_complete is not None

# now we make our assertions



+ 7
- 4
tests/handlers/test_sync.py View File

@@ -14,6 +14,8 @@
from typing import Optional
from unittest.mock import MagicMock, Mock, patch

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
@@ -23,6 +25,7 @@ from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock

import tests.unittest
import tests.utils
@@ -39,7 +42,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
room.register_servlets,
]

def prepare(self, reactor, clock, hs: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastores().main

@@ -47,7 +50,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth_blocking()

def test_wait_for_sync_for_user_auth_blocking(self):
def test_wait_for_sync_for_user_auth_blocking(self) -> None:
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = generate_sync_config(user_id1)
@@ -82,7 +85,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

def test_unknown_room_version(self):
def test_unknown_room_version(self) -> None:
"""
A room with an unknown room version should not break sync (and should be excluded).
"""
@@ -186,7 +189,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])

def test_ban_wins_race_with_join(self):
def test_ban_wins_race_with_join(self) -> None:
"""Rooms shouldn't appear under "joined" if a join loses a race to a ban.

A complicated edge case. Imagine the following scenario:


Loading…
Cancel
Save