Adds a return type to HomeServerTestCase.make_homeserver and deal with any variables which are no longer Any.tags/v1.78.0rc1
@@ -0,0 +1 @@ | |||||
Improve type hints. |
@@ -56,9 +56,6 @@ disallow_untyped_defs = False | |||||
[mypy-synapse.storage.database] | [mypy-synapse.storage.database] | ||||
disallow_untyped_defs = False | disallow_untyped_defs = False | ||||
[mypy-tests.unittest] | |||||
disallow_untyped_defs = False | |||||
[mypy-tests.util.caches.test_descriptors] | [mypy-tests.util.caches.test_descriptors] | ||||
disallow_untyped_defs = False | disallow_untyped_defs = False | ||||
@@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): | |||||
} | } | ||||
# Listen with the config | # Listen with the config | ||||
self.hs._listen_http(parse_listener_def(0, config)) | |||||
hs = self.hs | |||||
assert isinstance(hs, GenericWorkerServer) | |||||
hs._listen_http(parse_listener_def(0, config)) | |||||
# Grab the resource from the site that was told to listen | # Grab the resource from the site that was told to listen | ||||
site = self.reactor.tcpServers[0][1] | site = self.reactor.tcpServers[0][1] | ||||
@@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): | |||||
} | } | ||||
# Listen with the config | # Listen with the config | ||||
self.hs._listener_http(self.hs.config, parse_listener_def(0, config)) | |||||
hs = self.hs | |||||
assert isinstance(hs, SynapseHomeServer) | |||||
hs._listener_http(self.hs.config, parse_listener_def(0, config)) | |||||
# Grab the resource from the site that was told to listen | # Grab the resource from the site that was told to listen | ||||
site = self.reactor.tcpServers[0][1] | site = self.reactor.tcpServers[0][1] | ||||
@@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||||
key1 = signedjson.key.generate_signing_key("1") | key1 = signedjson.key.generate_signing_key("1") | ||||
r = self.hs.get_datastores().main.store_server_verify_keys( | r = self.hs.get_datastores().main.store_server_verify_keys( | ||||
"server9", | "server9", | ||||
time.time() * 1000, | |||||
int(time.time() * 1000), | |||||
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], | [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], | ||||
) | ) | ||||
self.get_success(r) | self.get_success(r) | ||||
@@ -287,7 +287,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||||
key1 = signedjson.key.generate_signing_key("1") | key1 = signedjson.key.generate_signing_key("1") | ||||
r = self.hs.get_datastores().main.store_server_verify_keys( | r = self.hs.get_datastores().main.store_server_verify_keys( | ||||
"server9", | "server9", | ||||
time.time() * 1000, | |||||
int(time.time() * 1000), | |||||
# None is not a valid value in FetchKeyResult, but we're abusing this | # None is not a valid value in FetchKeyResult, but we're abusing this | ||||
# API to insert null values into the database. The nulls get converted | # API to insert null values into the database. The nulls get converted | ||||
# to 0 when fetched in KeyStore.get_server_verify_keys. | # to 0 when fetched in KeyStore.get_server_verify_keys. | ||||
@@ -466,9 +466,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): | |||||
key_json = self.get_success( | key_json = self.get_success( | ||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) | self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) | ||||
) | ) | ||||
res = key_json[lookup_triplet] | |||||
self.assertEqual(len(res), 1) | |||||
res = res[0] | |||||
res_keys = key_json[lookup_triplet] | |||||
self.assertEqual(len(res_keys), 1) | |||||
res = res_keys[0] | |||||
self.assertEqual(res["key_id"], testverifykey_id) | self.assertEqual(res["key_id"], testverifykey_id) | ||||
self.assertEqual(res["from_server"], SERVER_NAME) | self.assertEqual(res["from_server"], SERVER_NAME) | ||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) | self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) | ||||
@@ -584,9 +584,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): | |||||
key_json = self.get_success( | key_json = self.get_success( | ||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) | self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) | ||||
) | ) | ||||
res = key_json[lookup_triplet] | |||||
self.assertEqual(len(res), 1) | |||||
res = res[0] | |||||
res_keys = key_json[lookup_triplet] | |||||
self.assertEqual(len(res_keys), 1) | |||||
res = res_keys[0] | |||||
self.assertEqual(res["key_id"], testverifykey_id) | self.assertEqual(res["key_id"], testverifykey_id) | ||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) | self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) | ||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) | self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) | ||||
@@ -705,9 +705,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): | |||||
key_json = self.get_success( | key_json = self.get_success( | ||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) | self.hs.get_datastores().main.get_server_keys_json([lookup_triplet]) | ||||
) | ) | ||||
res = key_json[lookup_triplet] | |||||
self.assertEqual(len(res), 1) | |||||
res = res[0] | |||||
res_keys = key_json[lookup_triplet] | |||||
self.assertEqual(len(res_keys), 1) | |||||
res = res_keys[0] | |||||
self.assertEqual(res["key_id"], testverifykey_id) | self.assertEqual(res["key_id"], testverifykey_id) | ||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) | self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) | ||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) | self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) | ||||
@@ -156,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
# Mock out the calls over federation. | # Mock out the calls over federation. | ||||
fed_transport_client = Mock(spec=["send_transaction"]) | |||||
fed_transport_client.send_transaction = simple_async_mock({}) | |||||
self.fed_transport_client = Mock(spec=["send_transaction"]) | |||||
self.fed_transport_client.send_transaction = simple_async_mock({}) | |||||
hs = self.setup_test_homeserver( | hs = self.setup_test_homeserver( | ||||
federation_transport_client=fed_transport_client, | |||||
federation_transport_client=self.fed_transport_client, | |||||
) | ) | ||||
load_legacy_presence_router(hs) | load_legacy_presence_router(hs) | ||||
@@ -422,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): | |||||
# | # | ||||
# Thus we reset the mock, and try sending all online local user | # Thus we reset the mock, and try sending all online local user | ||||
# presence again | # presence again | ||||
self.hs.get_federation_transport_client().send_transaction.reset_mock() | |||||
self.fed_transport_client.send_transaction.reset_mock() | |||||
# Broadcast local user online presence | # Broadcast local user online presence | ||||
self.get_success( | self.get_success( | ||||
@@ -447,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): | |||||
} | } | ||||
found_users = set() | found_users = set() | ||||
calls = ( | |||||
self.hs.get_federation_transport_client().send_transaction.call_args_list | |||||
) | |||||
calls = self.fed_transport_client.send_transaction.call_args_list | |||||
for call in calls: | for call in calls: | ||||
call_args = call[0] | call_args = call[0] | ||||
federation_transaction: Transaction = call_args[0] | federation_transaction: Transaction = call_args[0] | ||||
@@ -17,7 +17,7 @@ from unittest.mock import Mock | |||||
from synapse.api.errors import Codes, SynapseError | from synapse.api.errors import Codes, SynapseError | ||||
from synapse.rest import admin | from synapse.rest import admin | ||||
from synapse.rest.client import login, room | from synapse.rest.client import login, room | ||||
from synapse.types import JsonDict, UserID | |||||
from synapse.types import JsonDict, UserID, create_requester | |||||
from tests import unittest | from tests import unittest | ||||
from tests.test_utils import make_awaitable | from tests.test_utils import make_awaitable | ||||
@@ -56,7 +56,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||||
# Artificially raise the complexity | # Artificially raise the complexity | ||||
store = self.hs.get_datastores().main | store = self.hs.get_datastores().main | ||||
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) | |||||
async def get_current_state_event_counts(room_id: str) -> int: | |||||
return int(500 * 1.23) | |||||
store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] | |||||
# Get the room complexity again -- make sure it's our artificial value | # Get the room complexity again -- make sure it's our artificial value | ||||
channel = self.make_signed_federation_request( | channel = self.make_signed_federation_request( | ||||
@@ -75,12 +79,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||||
# Mock out some things, because we don't want to test the whole join | # Mock out some things, because we don't want to test the whole join | ||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | ||||
handler.federation_handler.do_invite_join = Mock( | |||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(("", 1)) | return_value=make_awaitable(("", 1)) | ||||
) | ) | ||||
d = handler._remote_join( | d = handler._remote_join( | ||||
None, | |||||
create_requester(u1), | |||||
["other.example.com"], | ["other.example.com"], | ||||
"roomid", | "roomid", | ||||
UserID.from_string(u1), | UserID.from_string(u1), | ||||
@@ -106,12 +110,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||||
# Mock out some things, because we don't want to test the whole join | # Mock out some things, because we don't want to test the whole join | ||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | ||||
handler.federation_handler.do_invite_join = Mock( | |||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(("", 1)) | return_value=make_awaitable(("", 1)) | ||||
) | ) | ||||
d = handler._remote_join( | d = handler._remote_join( | ||||
None, | |||||
create_requester(u1), | |||||
["other.example.com"], | ["other.example.com"], | ||||
"roomid", | "roomid", | ||||
UserID.from_string(u1), | UserID.from_string(u1), | ||||
@@ -144,17 +148,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||||
# Mock out some things, because we don't want to test the whole join | # Mock out some things, because we don't want to test the whole join | ||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) | fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) | ||||
handler.federation_handler.do_invite_join = Mock( | |||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(("", 1)) | return_value=make_awaitable(("", 1)) | ||||
) | ) | ||||
# Artificially raise the complexity | # Artificially raise the complexity | ||||
self.hs.get_datastores().main.get_current_state_event_counts = ( | |||||
lambda x: make_awaitable(600) | |||||
) | |||||
async def get_current_state_event_counts(room_id: str) -> int: | |||||
return 600 | |||||
self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] | |||||
d = handler._remote_join( | d = handler._remote_join( | ||||
None, | |||||
create_requester(u1), | |||||
["other.example.com"], | ["other.example.com"], | ||||
room_1, | room_1, | ||||
UserID.from_string(u1), | UserID.from_string(u1), | ||||
@@ -200,12 +205,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): | |||||
# Mock out some things, because we don't want to test the whole join | # Mock out some things, because we don't want to test the whole join | ||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | ||||
handler.federation_handler.do_invite_join = Mock( | |||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(("", 1)) | return_value=make_awaitable(("", 1)) | ||||
) | ) | ||||
d = handler._remote_join( | d = handler._remote_join( | ||||
None, | |||||
create_requester(u1), | |||||
["other.example.com"], | ["other.example.com"], | ||||
"roomid", | "roomid", | ||||
UserID.from_string(u1), | UserID.from_string(u1), | ||||
@@ -230,12 +235,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): | |||||
# Mock out some things, because we don't want to test the whole join | # Mock out some things, because we don't want to test the whole join | ||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) | ||||
handler.federation_handler.do_invite_join = Mock( | |||||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(("", 1)) | return_value=make_awaitable(("", 1)) | ||||
) | ) | ||||
d = handler._remote_join( | d = handler._remote_join( | ||||
None, | |||||
create_requester(u1), | |||||
["other.example.com"], | ["other.example.com"], | ||||
"roomid", | "roomid", | ||||
UserID.from_string(u1), | UserID.from_string(u1), | ||||
@@ -5,7 +5,11 @@ from twisted.test.proto_helpers import MemoryReactor | |||||
from synapse.api.constants import EventTypes | from synapse.api.constants import EventTypes | ||||
from synapse.events import EventBase | from synapse.events import EventBase | ||||
from synapse.federation.sender import PerDestinationQueue, TransactionManager | |||||
from synapse.federation.sender import ( | |||||
FederationSender, | |||||
PerDestinationQueue, | |||||
TransactionManager, | |||||
) | |||||
from synapse.federation.units import Edu, Transaction | from synapse.federation.units import Edu, Transaction | ||||
from synapse.rest import admin | from synapse.rest import admin | ||||
from synapse.rest.client import login, room | from synapse.rest.client import login, room | ||||
@@ -33,8 +37,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): | |||||
] | ] | ||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
self.federation_transport_client = Mock(spec=["send_transaction"]) | |||||
return self.setup_test_homeserver( | return self.setup_test_homeserver( | ||||
federation_transport_client=Mock(spec=["send_transaction"]), | |||||
federation_transport_client=self.federation_transport_client, | |||||
) | ) | ||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
@@ -52,10 +57,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): | |||||
self.pdus: List[JsonDict] = [] | self.pdus: List[JsonDict] = [] | ||||
self.failed_pdus: List[JsonDict] = [] | self.failed_pdus: List[JsonDict] = [] | ||||
self.is_online = True | self.is_online = True | ||||
self.hs.get_federation_transport_client().send_transaction.side_effect = ( | |||||
self.federation_transport_client.send_transaction.side_effect = ( | |||||
self.record_transaction | self.record_transaction | ||||
) | ) | ||||
federation_sender = hs.get_federation_sender() | |||||
assert isinstance(federation_sender, FederationSender) | |||||
self.federation_sender = federation_sender | |||||
def default_config(self) -> JsonDict: | def default_config(self) -> JsonDict: | ||||
config = super().default_config() | config = super().default_config() | ||||
config["federation_sender_instances"] = None | config["federation_sender_instances"] = None | ||||
@@ -229,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): | |||||
# let's delete the federation transmission queue | # let's delete the federation transmission queue | ||||
# (this pretends we are starting up fresh.) | # (this pretends we are starting up fresh.) | ||||
self.assertFalse( | self.assertFalse( | ||||
self.hs.get_federation_sender() | |||||
._per_destination_queues["host2"] | |||||
.transmission_loop_running | |||||
self.federation_sender._per_destination_queues[ | |||||
"host2" | |||||
].transmission_loop_running | |||||
) | ) | ||||
del self.hs.get_federation_sender()._per_destination_queues["host2"] | |||||
del self.federation_sender._per_destination_queues["host2"] | |||||
# let's also clear any backoffs | # let's also clear any backoffs | ||||
self.get_success( | self.get_success( | ||||
@@ -322,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): | |||||
# also fetch event 5 so we know its last_successful_stream_ordering later | # also fetch event 5 so we know its last_successful_stream_ordering later | ||||
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) | event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5)) | ||||
assert event_2.internal_metadata.stream_ordering is not None | |||||
self.get_success( | self.get_success( | ||||
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( | self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( | ||||
"host2", event_2.internal_metadata.stream_ordering | "host2", event_2.internal_metadata.stream_ordering | ||||
@@ -425,15 +435,16 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): | |||||
def wake_destination_track(destination: str) -> None: | def wake_destination_track(destination: str) -> None: | ||||
woken.append(destination) | woken.append(destination) | ||||
self.hs.get_federation_sender().wake_destination = wake_destination_track | |||||
self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment] | |||||
# cancel the pre-existing timer for _wake_destinations_needing_catchup | # cancel the pre-existing timer for _wake_destinations_needing_catchup | ||||
# this is because we are calling it manually rather than waiting for it | # this is because we are calling it manually rather than waiting for it | ||||
# to be called automatically | # to be called automatically | ||||
self.hs.get_federation_sender()._catchup_after_startup_timer.cancel() | |||||
assert self.federation_sender._catchup_after_startup_timer is not None | |||||
self.federation_sender._catchup_after_startup_timer.cancel() | |||||
self.get_success( | self.get_success( | ||||
self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0 | |||||
self.federation_sender._wake_destinations_needing_catchup(), by=5.0 | |||||
) | ) | ||||
# ASSERT (_wake_destinations_needing_catchup): | # ASSERT (_wake_destinations_needing_catchup): | ||||
@@ -475,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): | |||||
) | ) | ||||
) | ) | ||||
assert event_1.internal_metadata.stream_ordering is not None | |||||
self.get_success( | self.get_success( | ||||
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( | self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( | ||||
"host2", event_1.internal_metadata.stream_ordering | "host2", event_1.internal_metadata.stream_ordering | ||||
@@ -178,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase): | |||||
RoomVersions.V9, | RoomVersions.V9, | ||||
) | ) | ||||
) | ) | ||||
self.assertIsNotNone(pulled_pdu_info2) | |||||
assert pulled_pdu_info2 is not None | |||||
remote_pdu2 = pulled_pdu_info2.pdu | remote_pdu2 = pulled_pdu_info2.pdu | ||||
# Sanity check that we are working against the same event | # Sanity check that we are working against the same event | ||||
@@ -226,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase): | |||||
RoomVersions.V9, | RoomVersions.V9, | ||||
) | ) | ||||
) | ) | ||||
self.assertIsNotNone(pulled_pdu_info) | |||||
assert pulled_pdu_info is not None | |||||
remote_pdu = pulled_pdu_info.pdu | remote_pdu = pulled_pdu_info.pdu | ||||
# check the right call got made to the agent | # check the right call got made to the agent | ||||
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor | |||||
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms | from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms | ||||
from synapse.federation.units import Transaction | from synapse.federation.units import Transaction | ||||
from synapse.handlers.device import DeviceHandler | |||||
from synapse.rest import admin | from synapse.rest import admin | ||||
from synapse.rest.client import login | from synapse.rest.client import login | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
@@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||||
""" | """ | ||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
self.federation_transport_client = Mock(spec=["send_transaction"]) | |||||
hs = self.setup_test_homeserver( | hs = self.setup_test_homeserver( | ||||
federation_transport_client=Mock(spec=["send_transaction"]), | |||||
federation_transport_client=self.federation_transport_client, | |||||
) | ) | ||||
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] | hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] | ||||
@@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||||
return config | return config | ||||
def test_send_receipts(self) -> None: | def test_send_receipts(self) -> None: | ||||
mock_send_transaction = ( | |||||
self.hs.get_federation_transport_client().send_transaction | |||||
) | |||||
mock_send_transaction = self.federation_transport_client.send_transaction | |||||
mock_send_transaction.return_value = make_awaitable({}) | mock_send_transaction.return_value = make_awaitable({}) | ||||
sender = self.hs.get_federation_sender() | sender = self.hs.get_federation_sender() | ||||
@@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||||
) | ) | ||||
def test_send_receipts_thread(self) -> None: | def test_send_receipts_thread(self) -> None: | ||||
mock_send_transaction = ( | |||||
self.hs.get_federation_transport_client().send_transaction | |||||
) | |||||
mock_send_transaction = self.federation_transport_client.send_transaction | |||||
mock_send_transaction.return_value = make_awaitable({}) | mock_send_transaction.return_value = make_awaitable({}) | ||||
# Create receipts for: | # Create receipts for: | ||||
@@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||||
def test_send_receipts_with_backoff(self) -> None: | def test_send_receipts_with_backoff(self) -> None: | ||||
"""Send two receipts in quick succession; the second should be flushed, but | """Send two receipts in quick succession; the second should be flushed, but | ||||
only after 20ms""" | only after 20ms""" | ||||
mock_send_transaction = ( | |||||
self.hs.get_federation_transport_client().send_transaction | |||||
) | |||||
mock_send_transaction = self.federation_transport_client.send_transaction | |||||
mock_send_transaction.return_value = make_awaitable({}) | mock_send_transaction.return_value = make_awaitable({}) | ||||
sender = self.hs.get_federation_sender() | sender = self.hs.get_federation_sender() | ||||
@@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
] | ] | ||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
self.federation_transport_client = Mock( | |||||
spec=["send_transaction", "query_user_devices"] | |||||
) | |||||
return self.setup_test_homeserver( | return self.setup_test_homeserver( | ||||
federation_transport_client=Mock( | |||||
spec=["send_transaction", "query_user_devices"] | |||||
), | |||||
federation_transport_client=self.federation_transport_client, | |||||
) | ) | ||||
def default_config(self) -> JsonDict: | def default_config(self) -> JsonDict: | ||||
@@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment] | hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment] | ||||
device_handler = hs.get_device_handler() | |||||
assert isinstance(device_handler, DeviceHandler) | |||||
self.device_handler = device_handler | |||||
# whenever send_transaction is called, record the edu data | # whenever send_transaction is called, record the edu data | ||||
self.edus: List[JsonDict] = [] | self.edus: List[JsonDict] = [] | ||||
self.hs.get_federation_transport_client().send_transaction.side_effect = ( | |||||
self.federation_transport_client.send_transaction.side_effect = ( | |||||
self.record_transaction | self.record_transaction | ||||
) | ) | ||||
@@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
# Send the server a device list EDU for the other user, this will cause | # Send the server a device list EDU for the other user, this will cause | ||||
# it to try and resync the device lists. | # it to try and resync the device lists. | ||||
self.hs.get_federation_transport_client().query_user_devices.return_value = ( | |||||
self.federation_transport_client.query_user_devices.return_value = ( | |||||
make_awaitable( | make_awaitable( | ||||
{ | { | ||||
"stream_id": "1", | "stream_id": "1", | ||||
@@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
) | ) | ||||
self.get_success( | self.get_success( | ||||
self.hs.get_device_handler().device_list_updater.incoming_device_list_update( | |||||
self.device_handler.device_list_updater.incoming_device_list_update( | |||||
"host2", | "host2", | ||||
{ | { | ||||
"user_id": "@user2:host2", | "user_id": "@user2:host2", | ||||
@@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) | stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id) | ||||
# delete them again | # delete them again | ||||
self.get_success( | |||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) | |||||
) | |||||
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) | |||||
# We queue up device list updates to be sent over federation, so we | # We queue up device list updates to be sent over federation, so we | ||||
# advance to clear the queue. | # advance to clear the queue. | ||||
@@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
"""If the destination server is unreachable, all the updates should get sent on | """If the destination server is unreachable, all the updates should get sent on | ||||
recovery | recovery | ||||
""" | """ | ||||
mock_send_txn = self.hs.get_federation_transport_client().send_transaction | |||||
mock_send_txn = self.federation_transport_client.send_transaction | |||||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | ||||
# create devices | # create devices | ||||
@@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
self.login("user", "pass", device_id="D3") | self.login("user", "pass", device_id="D3") | ||||
# delete them again | # delete them again | ||||
self.get_success( | |||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) | |||||
) | |||||
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) | |||||
# We queue up device list updates to be sent over federation, so we | # We queue up device list updates to be sent over federation, so we | ||||
# advance to clear the queue. | # advance to clear the queue. | ||||
@@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
This case tests the behaviour when the server has never been reachable. | This case tests the behaviour when the server has never been reachable. | ||||
""" | """ | ||||
mock_send_txn = self.hs.get_federation_transport_client().send_transaction | |||||
mock_send_txn = self.federation_transport_client.send_transaction | |||||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | ||||
# create devices | # create devices | ||||
@@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
self.login("user", "pass", device_id="D3") | self.login("user", "pass", device_id="D3") | ||||
# delete them again | # delete them again | ||||
self.get_success( | |||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) | |||||
) | |||||
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) | |||||
# We queue up device list updates to be sent over federation, so we | # We queue up device list updates to be sent over federation, so we | ||||
# advance to clear the queue. | # advance to clear the queue. | ||||
@@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) | self.check_device_update_edu(self.edus.pop(0), u1, "D1", None) | ||||
# now the server goes offline | # now the server goes offline | ||||
mock_send_txn = self.hs.get_federation_transport_client().send_transaction | |||||
mock_send_txn = self.federation_transport_client.send_transaction | |||||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | ||||
self.login("user", "pass", device_id="D2") | self.login("user", "pass", device_id="D2") | ||||
@@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||||
self.reactor.advance(1) | self.reactor.advance(1) | ||||
# delete them again | # delete them again | ||||
self.get_success( | |||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"]) | |||||
) | |||||
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"])) | |||||
self.assertGreaterEqual(mock_send_txn.call_count, 3) | self.assertGreaterEqual(mock_send_txn.call_count, 3) | ||||
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) | |||||
# Mock out application services, and allow defining our own in tests | # Mock out application services, and allow defining our own in tests | ||||
self._services: List[ApplicationService] = [] | self._services: List[ApplicationService] = [] | ||||
self.hs.get_datastores().main.get_app_services = Mock( | |||||
self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] | |||||
return_value=self._services | return_value=self._services | ||||
) | ) | ||||
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
cas_response = CasResponse("test_user", {}) | cas_response = CasResponse("test_user", {}) | ||||
request = _mock_request() | request = _mock_request() | ||||
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
# Map a user via SSO. | # Map a user via SSO. | ||||
cas_response = CasResponse("test_user", {}) | cas_response = CasResponse("test_user", {}) | ||||
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
cas_response = CasResponse("föö", {}) | cas_response = CasResponse("föö", {}) | ||||
request = _mock_request() | request = _mock_request() | ||||
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
# The response doesn't have the proper userGroup or department. | # The response doesn't have the proper userGroup or department. | ||||
cas_response = CasResponse("test_user", {}) | cas_response = CasResponse("test_user", {}) | ||||
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor | |||||
from synapse.api.constants import RoomEncryptionAlgorithms | from synapse.api.constants import RoomEncryptionAlgorithms | ||||
from synapse.api.errors import Codes, SynapseError | from synapse.api.errors import Codes, SynapseError | ||||
from synapse.handlers.device import DeviceHandler | |||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.types import JsonDict | from synapse.types import JsonDict | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
# we should now have an unused alg1 key | # we should now have an unused alg1 key | ||||
res = self.get_success( | |||||
fallback_res = self.get_success( | |||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||
) | ) | ||||
self.assertEqual(res, ["alg1"]) | |||||
self.assertEqual(fallback_res, ["alg1"]) | |||||
# claiming an OTK when no OTKs are available should return the fallback | # claiming an OTK when no OTKs are available should return the fallback | ||||
# key | # key | ||||
res = self.get_success( | |||||
claim_res = self.get_success( | |||||
self.handler.claim_one_time_keys( | self.handler.claim_one_time_keys( | ||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | ||||
) | ) | ||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
res, | |||||
claim_res, | |||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, | {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, | ||||
) | ) | ||||
# we shouldn't have any unused fallback keys again | # we shouldn't have any unused fallback keys again | ||||
res = self.get_success( | |||||
unused_res = self.get_success( | |||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||
) | ) | ||||
self.assertEqual(res, []) | |||||
self.assertEqual(unused_res, []) | |||||
# claiming an OTK again should return the same fallback key | # claiming an OTK again should return the same fallback key | ||||
res = self.get_success( | |||||
claim_res = self.get_success( | |||||
self.handler.claim_one_time_keys( | self.handler.claim_one_time_keys( | ||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | ||||
) | ) | ||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
res, | |||||
claim_res, | |||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, | {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, | ||||
) | ) | ||||
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
) | ) | ||||
res = self.get_success( | |||||
unused_res = self.get_success( | |||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||
) | ) | ||||
self.assertEqual(res, []) | |||||
self.assertEqual(unused_res, []) | |||||
# uploading a new fallback key should result in an unused fallback key | # uploading a new fallback key should result in an unused fallback key | ||||
self.get_success( | self.get_success( | ||||
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
) | ) | ||||
res = self.get_success( | |||||
unused_res = self.get_success( | |||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | self.store.get_e2e_unused_fallback_key_types(local_user, device_id) | ||||
) | ) | ||||
self.assertEqual(res, ["alg1"]) | |||||
self.assertEqual(unused_res, ["alg1"]) | |||||
# if the user uploads a one-time key, the next claim should fetch the | # if the user uploads a one-time key, the next claim should fetch the | ||||
# one-time key, and then go back to the fallback | # one-time key, and then go back to the fallback | ||||
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
) | ) | ||||
res = self.get_success( | |||||
claim_res = self.get_success( | |||||
self.handler.claim_one_time_keys( | self.handler.claim_one_time_keys( | ||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | ||||
) | ) | ||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
res, | |||||
claim_res, | |||||
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, | {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, | ||||
) | ) | ||||
res = self.get_success( | |||||
claim_res = self.get_success( | |||||
self.handler.claim_one_time_keys( | self.handler.claim_one_time_keys( | ||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | ||||
) | ) | ||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
res, | |||||
claim_res, | |||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, | {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, | ||||
) | ) | ||||
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
) | ) | ||||
res = self.get_success( | |||||
claim_res = self.get_success( | |||||
self.handler.claim_one_time_keys( | self.handler.claim_one_time_keys( | ||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None | ||||
) | ) | ||||
) | ) | ||||
self.assertEqual( | self.assertEqual( | ||||
res, | |||||
claim_res, | |||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, | {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, | ||||
) | ) | ||||
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) | self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) | ||||
# upload two device keys, which will be signed later by the self-signing key | # upload two device keys, which will be signed later by the self-signing key | ||||
device_key_1 = { | |||||
device_key_1: JsonDict = { | |||||
"user_id": local_user, | "user_id": local_user, | ||||
"device_id": "abc", | "device_id": "abc", | ||||
"algorithms": [ | "algorithms": [ | ||||
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
}, | }, | ||||
"signatures": {local_user: {"ed25519:abc": "base64+signature"}}, | "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, | ||||
} | } | ||||
device_key_2 = { | |||||
device_key_2: JsonDict = { | |||||
"user_id": local_user, | "user_id": local_user, | ||||
"device_id": "def", | "device_id": "def", | ||||
"algorithms": [ | "algorithms": [ | ||||
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
} | } | ||||
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) | self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) | ||||
device_handler = self.hs.get_device_handler() | |||||
assert isinstance(device_handler, DeviceHandler) | |||||
e = self.get_failure( | e = self.get_failure( | ||||
self.hs.get_device_handler().check_device_registered( | |||||
device_handler.check_device_registered( | |||||
user_id=local_user, | user_id=local_user, | ||||
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", | device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", | ||||
initial_device_display_name="new display name", | initial_device_display_name="new display name", | ||||
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
device_id = "xyz" | device_id = "xyz" | ||||
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA | # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA | ||||
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" | device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" | ||||
device_key = { | |||||
device_key: JsonDict = { | |||||
"user_id": local_user, | "user_id": local_user, | ||||
"device_id": device_id, | "device_id": device_id, | ||||
"algorithms": [ | "algorithms": [ | ||||
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 | # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 | ||||
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" | master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" | ||||
master_key = { | |||||
master_key: JsonDict = { | |||||
"user_id": local_user, | "user_id": local_user, | ||||
"usage": ["master"], | "usage": ["master"], | ||||
"keys": {"ed25519:" + master_pubkey: master_pubkey}, | "keys": {"ed25519:" + master_pubkey: master_pubkey}, | ||||
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
# the first user | # the first user | ||||
other_user = "@otherboris:" + self.hs.hostname | other_user = "@otherboris:" + self.hs.hostname | ||||
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" | other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" | ||||
other_master_key = { | |||||
other_master_key: JsonDict = { | |||||
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI | # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI | ||||
"user_id": other_user, | "user_id": other_user, | ||||
"usage": ["master"], | "usage": ["master"], | ||||
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" | remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" | ||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | ||||
self.hs.get_federation_client().query_client_keys = mock.Mock( | |||||
self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable( | return_value=make_awaitable( | ||||
{ | { | ||||
"device_keys": {remote_user_id: {}}, | "device_keys": {remote_user_id: {}}, | ||||
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" | remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" | ||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | ||||
self.hs.get_federation_client().query_user_devices = mock.Mock( | |||||
self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable( | return_value=make_awaitable( | ||||
{ | { | ||||
"user_id": remote_user_id, | "user_id": remote_user_id, | ||||
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): | |||||
# We mock out the FederationClient.backfill method, to pretend that a remote | # We mock out the FederationClient.backfill method, to pretend that a remote | ||||
# server has returned our fake event. | # server has returned our fake event. | ||||
federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) | federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) | ||||
self.hs.get_federation_client().backfill = federation_client_backfill_mock | |||||
self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] | |||||
# We also mock the persist method with a side effect of itself. This allows us | # We also mock the persist method with a side effect of itself. This allows us | ||||
# to track when it has been called while preserving its function. | # to track when it has been called while preserving its function. | ||||
persist_events_and_notify_mock = Mock( | persist_events_and_notify_mock = Mock( | ||||
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify | side_effect=self.hs.get_federation_event_handler().persist_events_and_notify | ||||
) | ) | ||||
self.hs.get_federation_event_handler().persist_events_and_notify = ( | |||||
self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment] | |||||
persist_events_and_notify_mock | persist_events_and_notify_mock | ||||
) | ) | ||||
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): | |||||
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room | fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room | ||||
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): | ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): | ||||
# Start the partial state sync. | # Start the partial state sync. | ||||
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") | |||||
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1) | self.assertEqual(mock_sync_partial_state_room.call_count, 1) | ||||
# Try to start another partial state sync. | # Try to start another partial state sync. | ||||
# Nothing should happen. | # Nothing should happen. | ||||
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") | |||||
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1) | self.assertEqual(mock_sync_partial_state_room.call_count, 1) | ||||
# End the partial state sync | # End the partial state sync | ||||
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): | |||||
# The next attempt to start the partial state sync should work. | # The next attempt to start the partial state sync should work. | ||||
is_partial_state = True | is_partial_state = True | ||||
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") | |||||
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 2) | self.assertEqual(mock_sync_partial_state_room.call_count, 2) | ||||
def test_partial_state_room_sync_restart(self) -> None: | def test_partial_state_room_sync_restart(self) -> None: | ||||
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): | |||||
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room | fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room | ||||
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): | ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): | ||||
# Start the partial state sync. | # Start the partial state sync. | ||||
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") | |||||
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1) | self.assertEqual(mock_sync_partial_state_room.call_count, 1) | ||||
# Fail the partial state sync. | # Fail the partial state sync. | ||||
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1) | self.assertEqual(mock_sync_partial_state_room.call_count, 1) | ||||
# Start the partial state sync again. | # Start the partial state sync again. | ||||
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") | |||||
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 2) | self.assertEqual(mock_sync_partial_state_room.call_count, 2) | ||||
# Deduplicate another partial state sync. | # Deduplicate another partial state sync. | ||||
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") | |||||
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 2) | self.assertEqual(mock_sync_partial_state_room.call_count, 2) | ||||
# Fail the partial state sync. | # Fail the partial state sync. | ||||
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): | |||||
self.assertEqual(mock_sync_partial_state_room.call_count, 3) | self.assertEqual(mock_sync_partial_state_room.call_count, 3) | ||||
mock_sync_partial_state_room.assert_called_with( | mock_sync_partial_state_room.assert_called_with( | ||||
initial_destination="hs3", | initial_destination="hs3", | ||||
other_destinations=["hs2"], | |||||
other_destinations={"hs2"}, | |||||
room_id="room_id", | room_id="room_id", | ||||
) | ) |
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext | |||||
from synapse.rest import admin | from synapse.rest import admin | ||||
from synapse.rest.client import login, room | from synapse.rest.client import login, room | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.state import StateResolutionStore | |||||
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort | from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort | ||||
from synapse.types import JsonDict | from synapse.types import JsonDict | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||||
if prev_exists_as_outlier: | if prev_exists_as_outlier: | ||||
prev_event.internal_metadata.outlier = True | prev_event.internal_metadata.outlier = True | ||||
persistence = self.hs.get_storage_controllers().persistence | persistence = self.hs.get_storage_controllers().persistence | ||||
assert persistence is not None | |||||
self.get_success( | self.get_success( | ||||
persistence.persist_event( | persistence.persist_event( | ||||
prev_event, | prev_event, | ||||
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||||
bert_member_event.event_id: bert_member_event, | bert_member_event.event_id: bert_member_event, | ||||
rejected_kick_event.event_id: rejected_kick_event, | rejected_kick_event.event_id: rejected_kick_event, | ||||
}, | }, | ||||
state_res_store=main_store, | |||||
state_res_store=StateResolutionStore(main_store), | |||||
) | ) | ||||
), | ), | ||||
[bert_member_event.event_id, rejected_kick_event.event_id], | [bert_member_event.event_id, rejected_kick_event.event_id], | ||||
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||||
rejected_power_levels_event.event_id, | rejected_power_levels_event.event_id, | ||||
], | ], | ||||
event_map={}, | event_map={}, | ||||
state_res_store=main_store, | |||||
state_res_store=StateResolutionStore(main_store), | |||||
full_conflicted_set=set(), | full_conflicted_set=set(), | ||||
) | ) | ||||
), | ), | ||||
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase): | |||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
self.handler = self.hs.get_event_creation_handler() | self.handler = self.hs.get_event_creation_handler() | ||||
self._persist_event_storage_controller = ( | |||||
self.hs.get_storage_controllers().persistence | |||||
) | |||||
persistence = self.hs.get_storage_controllers().persistence | |||||
assert persistence is not None | |||||
self._persist_event_storage_controller = persistence | |||||
self.user_id = self.register_user("tester", "foobar") | self.user_id = self.register_user("tester", "foobar") | ||||
self.access_token = self.login("tester", "foobar") | self.access_token = self.login("tester", "foobar") | ||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) | self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) | ||||
self.info = self.get_success( | |||||
info = self.get_success( | |||||
self.hs.get_datastores().main.get_user_by_access_token( | self.hs.get_datastores().main.get_user_by_access_token( | ||||
self.access_token, | self.access_token, | ||||
) | ) | ||||
) | ) | ||||
self.token_id = self.info.token_id | |||||
assert info is not None | |||||
self.token_id = info.token_id | |||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id) | self.requester = create_requester(self.user_id, access_token_id=self.token_id) | ||||
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||||
username: The username to use for the test. | username: The username to use for the test. | ||||
registration: Whether to test with registration URLs. | registration: Whether to test with registration URLs. | ||||
""" | """ | ||||
self.hs.get_identity_handler().send_threepid_validation = Mock( | |||||
self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(0), | return_value=make_awaitable(0), | ||||
) | ) | ||||
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||||
@override_config({"limit_usage_by_mau": True}) | @override_config({"limit_usage_by_mau": True}) | ||||
def test_get_or_create_user_mau_not_blocked(self) -> None: | def test_get_or_create_user_mau_not_blocked(self) -> None: | ||||
self.store.count_monthly_users = Mock( | |||||
self.store.count_monthly_users = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) | return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) | ||||
) | ) | ||||
# Ensure does not throw exception | # Ensure does not throw exception | ||||
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: | def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: | ||||
room_alias_str = "#room:test" | room_alias_str = "#room:test" | ||||
self.store.count_real_users = Mock(return_value=make_awaitable(1)) | |||||
self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] | |||||
self.store.is_real_user = Mock(return_value=make_awaitable(True)) | self.store.is_real_user = Mock(return_value=make_awaitable(True)) | ||||
user_id = self.get_success(self.handler.register_user(localpart="real")) | user_id = self.get_success(self.handler.register_user(localpart="real")) | ||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | ||||
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( | def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( | ||||
self, | self, | ||||
) -> None: | ) -> None: | ||||
self.store.count_real_users = Mock(return_value=make_awaitable(2)) | |||||
self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] | |||||
self.store.is_real_user = Mock(return_value=make_awaitable(True)) | self.store.is_real_user = Mock(return_value=make_awaitable(True)) | ||||
user_id = self.get_success(self.handler.register_user(localpart="real")) | user_id = self.get_success(self.handler.register_user(localpart="real")) | ||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | ||||
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||||
# Ensure the room is properly not federated. | # Ensure the room is properly not federated. | ||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | ||||
assert room is not None | |||||
self.assertFalse(room["federatable"]) | self.assertFalse(room["federatable"]) | ||||
self.assertFalse(room["public"]) | self.assertFalse(room["public"]) | ||||
self.assertEqual(room["join_rules"], "public") | self.assertEqual(room["join_rules"], "public") | ||||
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||||
# Ensure the room is properly a public room. | # Ensure the room is properly a public room. | ||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | ||||
assert room is not None | |||||
self.assertEqual(room["join_rules"], "public") | self.assertEqual(room["join_rules"], "public") | ||||
# Both users should be in the room. | # Both users should be in the room. | ||||
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||||
# Ensure the room is properly a private room. | # Ensure the room is properly a private room. | ||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | ||||
assert room is not None | |||||
self.assertFalse(room["public"]) | self.assertFalse(room["public"]) | ||||
self.assertEqual(room["join_rules"], "invite") | self.assertEqual(room["join_rules"], "invite") | ||||
self.assertEqual(room["guest_access"], "can_join") | self.assertEqual(room["guest_access"], "can_join") | ||||
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||||
# Ensure the room is properly a private room. | # Ensure the room is properly a private room. | ||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) | ||||
assert room is not None | |||||
self.assertFalse(room["public"]) | self.assertFalse(room["public"]) | ||||
self.assertEqual(room["join_rules"], "invite") | self.assertEqual(room["join_rules"], "invite") | ||||
self.assertEqual(room["guest_access"], "can_join") | self.assertEqual(room["guest_access"], "can_join") | ||||
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
# send a mocked-up SAML response to the callback | # send a mocked-up SAML response to the callback | ||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | ||||
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
# Map a user via SSO. | # Map a user via SSO. | ||||
saml_response = FakeAuthnResponse( | saml_response = FakeAuthnResponse( | ||||
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
# mock out the error renderer too | # mock out the error renderer too | ||||
sso_handler = self.hs.get_sso_handler() | sso_handler = self.hs.get_sso_handler() | ||||
sso_handler.render_error = Mock(return_value=None) | |||||
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] | |||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) | saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) | ||||
request = _mock_request() | request = _mock_request() | ||||
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler and error renderer | # stub out the auth handler and error renderer | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
sso_handler = self.hs.get_sso_handler() | sso_handler = self.hs.get_sso_handler() | ||||
sso_handler.render_error = Mock(return_value=None) | |||||
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] | |||||
# register a user to occupy the first-choice MXID | # register a user to occupy the first-choice MXID | ||||
store = self.hs.get_datastores().main | store = self.hs.get_datastores().main | ||||
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() | |||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
# The response doesn't have the proper userGroup or department. | # The response doesn't have the proper userGroup or department. | ||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | ||||
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||||
mock_keyring.verify_json_for_server.return_value = make_awaitable(True) | mock_keyring.verify_json_for_server.return_value = make_awaitable(True) | ||||
# we mock out the federation client too | # we mock out the federation client too | ||||
mock_federation_client = Mock(spec=["put_json"]) | |||||
mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) | |||||
self.mock_federation_client = Mock(spec=["put_json"]) | |||||
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) | |||||
# the tests assume that we are starting at unix time 1000 | # the tests assume that we are starting at unix time 1000 | ||||
reactor.pump((1000,)) | reactor.pump((1000,)) | ||||
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||||
self.mock_hs_notifier = Mock() | self.mock_hs_notifier = Mock() | ||||
hs = self.setup_test_homeserver( | hs = self.setup_test_homeserver( | ||||
notifier=self.mock_hs_notifier, | notifier=self.mock_hs_notifier, | ||||
federation_http_client=mock_federation_client, | |||||
federation_http_client=self.mock_federation_client, | |||||
keyring=mock_keyring, | keyring=mock_keyring, | ||||
replication_streams={}, | replication_streams={}, | ||||
) | ) | ||||
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
) | ) | ||||
put_json = self.hs.get_federation_http_client().put_json | |||||
put_json.assert_called_once_with( | |||||
self.mock_federation_client.put_json.assert_called_once_with( | |||||
"farm", | "farm", | ||||
path="/_matrix/federation/v1/send/1000000", | path="/_matrix/federation/v1/send/1000000", | ||||
data=_expect_edu_transaction( | data=_expect_edu_transaction( | ||||
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) | self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) | ||||
put_json = self.hs.get_federation_http_client().put_json | |||||
put_json.assert_called_once_with( | |||||
self.mock_federation_client.put_json.assert_called_once_with( | |||||
"farm", | "farm", | ||||
path="/_matrix/federation/v1/send/1000000", | path="/_matrix/federation/v1/send/1000000", | ||||
data=_expect_edu_transaction( | data=_expect_edu_transaction( | ||||
@@ -11,7 +11,7 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Tuple | |||||
from typing import Any, Tuple | |||||
from unittest.mock import Mock, patch | from unittest.mock import Mock, patch | ||||
from urllib.parse import quote | from urllib.parse import quote | ||||
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService | |||||
from synapse.rest.client import login, register, room, user_directory | from synapse.rest.client import login, register, room, user_directory | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.storage.roommember import ProfileInfo | from synapse.storage.roommember import ProfileInfo | ||||
from synapse.types import create_requester | |||||
from synapse.types import UserProfile, create_requester | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests import unittest | from tests import unittest | ||||
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event | |||||
from tests.unittest import override_config | from tests.unittest import override_config | ||||
# A spam checker which doesn't implement anything, so create a bare object. | |||||
class UselessSpamChecker: | |||||
def __init__(self, config: Any): | |||||
pass | |||||
class UserDirectoryTestCase(unittest.HomeserverTestCase): | class UserDirectoryTestCase(unittest.HomeserverTestCase): | ||||
"""Tests the UserDirectoryHandler. | """Tests the UserDirectoryHandler. | ||||
@@ -773,7 +779,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||||
s = self.get_success(self.handler.search_users(u1, "user2", 10)) | s = self.get_success(self.handler.search_users(u1, "user2", 10)) | ||||
self.assertEqual(len(s["results"]), 1) | self.assertEqual(len(s["results"]), 1) | ||||
async def allow_all(user_profile: ProfileInfo) -> bool: | |||||
async def allow_all(user_profile: UserProfile) -> bool: | |||||
# Allow all users. | # Allow all users. | ||||
return False | return False | ||||
@@ -787,7 +793,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||||
self.assertEqual(len(s["results"]), 1) | self.assertEqual(len(s["results"]), 1) | ||||
# Configure a spam checker that filters all users. | # Configure a spam checker that filters all users. | ||||
async def block_all(user_profile: ProfileInfo) -> bool: | |||||
async def block_all(user_profile: UserProfile) -> bool: | |||||
# All users are spammy. | # All users are spammy. | ||||
return True | return True | ||||
@@ -797,6 +803,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||||
s = self.get_success(self.handler.search_users(u1, "user2", 10)) | s = self.get_success(self.handler.search_users(u1, "user2", 10)) | ||||
self.assertEqual(len(s["results"]), 0) | self.assertEqual(len(s["results"]), 0) | ||||
@override_config( | |||||
{ | |||||
"spam_checker": { | |||||
"module": "tests.handlers.test_user_directory.UselessSpamChecker" | |||||
} | |||||
} | |||||
) | |||||
def test_legacy_spam_checker(self) -> None: | def test_legacy_spam_checker(self) -> None: | ||||
""" | """ | ||||
A spam checker without the expected method should be ignored. | A spam checker without the expected method should be ignored. | ||||
@@ -825,11 +838,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||||
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) | self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) | ||||
self.assertEqual(public_users, set()) | self.assertEqual(public_users, set()) | ||||
# Configure a spam checker. | |||||
spam_checker = self.hs.get_spam_checker() | |||||
# The spam checker doesn't need any methods, so create a bare object. | |||||
spam_checker.spam_checker = object() | |||||
# We get one search result when searching for user2 by user1. | # We get one search result when searching for user2 by user1. | ||||
s = self.get_success(self.handler.search_users(u1, "user2", 10)) | s = self.get_success(self.handler.search_users(u1, "user2", 10)) | ||||
self.assertEqual(len(s["results"]), 1) | self.assertEqual(len(s["results"]), 1) | ||||
@@ -954,10 +962,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
context = self.get_success(unpersisted_context.persist(event)) | context = self.get_success(unpersisted_context.persist(event)) | ||||
self.get_success( | |||||
self.hs.get_storage_controllers().persistence.persist_event(event, context) | |||||
) | |||||
persistence = self.hs.get_storage_controllers().persistence | |||||
assert persistence is not None | |||||
self.get_success(persistence.persist_event(event, context)) | |||||
def test_local_user_leaving_room_remains_in_user_directory(self) -> None: | def test_local_user_leaving_room_remains_in_user_directory(self) -> None: | ||||
"""We've chosen to simplify the user directory's implementation by | """We've chosen to simplify the user directory's implementation by | ||||
@@ -68,11 +68,11 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
# Mock out the calls over federation. | # Mock out the calls over federation. | ||||
fed_transport_client = Mock(spec=["send_transaction"]) | |||||
fed_transport_client.send_transaction = simple_async_mock({}) | |||||
self.fed_transport_client = Mock(spec=["send_transaction"]) | |||||
self.fed_transport_client.send_transaction = simple_async_mock({}) | |||||
return self.setup_test_homeserver( | return self.setup_test_homeserver( | ||||
federation_transport_client=fed_transport_client, | |||||
federation_transport_client=self.fed_transport_client, | |||||
) | ) | ||||
def test_can_register_user(self) -> None: | def test_can_register_user(self) -> None: | ||||
@@ -417,7 +417,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||||
# | # | ||||
# Thus we reset the mock, and try sending online local user | # Thus we reset the mock, and try sending online local user | ||||
# presence again | # presence again | ||||
self.hs.get_federation_transport_client().send_transaction.reset_mock() | |||||
self.fed_transport_client.send_transaction.reset_mock() | |||||
# Broadcast local user online presence | # Broadcast local user online presence | ||||
self.get_success( | self.get_success( | ||||
@@ -429,9 +429,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||||
# Check that a presence update was sent as part of a federation transaction | # Check that a presence update was sent as part of a federation transaction | ||||
found_update = False | found_update = False | ||||
calls = ( | |||||
self.hs.get_federation_transport_client().send_transaction.call_args_list | |||||
) | |||||
calls = self.fed_transport_client.send_transaction.call_args_list | |||||
for call in calls: | for call in calls: | ||||
call_args = call[0] | call_args = call[0] | ||||
federation_transaction: Transaction = call_args[0] | federation_transaction: Transaction = call_args[0] | ||||
@@ -581,7 +579,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||||
mocked_remote_join = simple_async_mock( | mocked_remote_join = simple_async_mock( | ||||
return_value=("fake-event-id", fake_stream_id) | return_value=("fake-event-id", fake_stream_id) | ||||
) | ) | ||||
self.hs.get_room_member_handler()._remote_join = mocked_remote_join | |||||
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] | |||||
fake_remote_host = f"{self.module_api.server_name}-remote" | fake_remote_host = f"{self.module_api.server_name}-remote" | ||||
# Given that the join is to be faked, we expect the relevant join event not to | # Given that the join is to be faked, we expect the relevant join event not to | ||||
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor | |||||
import synapse.rest.admin | import synapse.rest.admin | ||||
from synapse.api.errors import Codes, SynapseError | from synapse.api.errors import Codes, SynapseError | ||||
from synapse.push.emailpusher import EmailPusher | |||||
from synapse.rest.client import login, room | from synapse.rest.client import login, room | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
@@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(self.access_token) | self.hs.get_datastores().main.get_user_by_access_token(self.access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
self.token_id = user_tuple.token_id | self.token_id = user_tuple.token_id | ||||
# We need to add email to account before we can create a pusher. | # We need to add email to account before we can create a pusher. | ||||
@@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase): | |||||
) | ) | ||||
) | ) | ||||
self.pusher = self.get_success( | |||||
pusher = self.get_success( | |||||
self.hs.get_pusherpool().add_or_update_pusher( | self.hs.get_pusherpool().add_or_update_pusher( | ||||
user_id=self.user_id, | user_id=self.user_id, | ||||
access_token=self.token_id, | access_token=self.token_id, | ||||
@@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase): | |||||
data={}, | data={}, | ||||
) | ) | ||||
) | ) | ||||
assert isinstance(pusher, EmailPusher) | |||||
self.pusher = pusher | |||||
self.auth_handler = hs.get_auth_handler() | self.auth_handler = hs.get_auth_handler() | ||||
self.store = hs.get_datastores().main | self.store = hs.get_datastores().main | ||||
@@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase): | |||||
) | ) | ||||
# check that the pusher for that email address has been deleted | # check that the pusher for that email address has been deleted | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by( | |||||
{"user_name": self.user_id} | |||||
) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 0) | self.assertEqual(len(pushers), 0) | ||||
def test_remove_unlinked_pushers_background_job(self) -> None: | def test_remove_unlinked_pushers_background_job(self) -> None: | ||||
@@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase): | |||||
self.wait_for_background_updates() | self.wait_for_background_updates() | ||||
# Check that all pushers with unlinked addresses were deleted | # Check that all pushers with unlinked addresses were deleted | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by( | |||||
{"user_name": self.user_id} | |||||
) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 0) | self.assertEqual(len(pushers), 0) | ||||
def _check_for_mail(self) -> Tuple[Sequence, Dict]: | def _check_for_mail(self) -> Tuple[Sequence, Dict]: | ||||
@@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase): | |||||
that notification. | that notification. | ||||
""" | """ | ||||
# Get the stream ordering before it gets sent | # Get the stream ordering before it gets sent | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by( | |||||
{"user_name": self.user_id} | |||||
) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 1) | self.assertEqual(len(pushers), 1) | ||||
last_stream_ordering = pushers[0].last_stream_ordering | last_stream_ordering = pushers[0].last_stream_ordering | ||||
@@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase): | |||||
self.pump(10) | self.pump(10) | ||||
# It hasn't succeeded yet, so the stream ordering shouldn't have moved | # It hasn't succeeded yet, so the stream ordering shouldn't have moved | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by( | |||||
{"user_name": self.user_id} | |||||
) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 1) | self.assertEqual(len(pushers), 1) | ||||
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) | self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) | ||||
@@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase): | |||||
self.assertEqual(len(self.email_attempts), 1) | self.assertEqual(len(self.email_attempts), 1) | ||||
# The stream ordering has increased | # The stream ordering has increased | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by( | |||||
{"user_name": self.user_id} | |||||
) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 1) | self.assertEqual(len(pushers), 1) | ||||
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) | self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) | ||||
@@ -11,7 +11,7 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import List, Optional, Tuple | |||||
from typing import Any, List, Tuple | |||||
from unittest.mock import Mock | from unittest.mock import Mock | ||||
from twisted.internet.defer import Deferred | from twisted.internet.defer import Deferred | ||||
@@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable | |||||
from synapse.push import PusherConfig, PusherConfigException | from synapse.push import PusherConfig, PusherConfigException | ||||
from synapse.rest.client import login, push_rule, pusher, receipts, room | from synapse.rest.client import login, push_rule, pusher, receipts, room | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.storage.databases.main.registration import TokenLookupResult | |||||
from synapse.types import JsonDict | from synapse.types import JsonDict | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
@@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
def test_data(data: Optional[JsonDict]) -> None: | |||||
def test_data(data: Any) -> None: | |||||
self.get_failure( | self.get_failure( | ||||
self.hs.get_pusherpool().add_or_update_pusher( | self.hs.get_pusherpool().add_or_update_pusher( | ||||
user_id=user_id, | user_id=user_id, | ||||
@@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
self.helper.send(room, body="There!", tok=other_access_token) | self.helper.send(room, body="There!", tok=other_access_token) | ||||
# Get the stream ordering before it gets sent | # Get the stream ordering before it gets sent | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 1) | self.assertEqual(len(pushers), 1) | ||||
last_stream_ordering = pushers[0].last_stream_ordering | last_stream_ordering = pushers[0].last_stream_ordering | ||||
@@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
self.pump() | self.pump() | ||||
# It hasn't succeeded yet, so the stream ordering shouldn't have moved | # It hasn't succeeded yet, so the stream ordering shouldn't have moved | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 1) | self.assertEqual(len(pushers), 1) | ||||
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) | self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering) | ||||
@@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
self.pump() | self.pump() | ||||
# The stream ordering has increased | # The stream ordering has increased | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 1) | self.assertEqual(len(pushers), 1) | ||||
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) | self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) | ||||
last_stream_ordering = pushers[0].last_stream_ordering | last_stream_ordering = pushers[0].last_stream_ordering | ||||
@@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
self.pump() | self.pump() | ||||
# The stream ordering has increased, again | # The stream ordering has increased, again | ||||
pushers = self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
pushers = list( | |||||
self.get_success( | |||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) | |||||
) | |||||
) | ) | ||||
pushers = list(pushers) | |||||
self.assertEqual(len(pushers), 1) | self.assertEqual(len(pushers), 1) | ||||
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) | self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering) | ||||
@@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
user_tuple = self.get_success( | user_tuple = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_tuple is not None | |||||
token_id = user_tuple.token_id | token_id = user_tuple.token_id | ||||
device_id = user_tuple.device_id | device_id = user_tuple.device_id | ||||
@@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
) | ) | ||||
# Look up the user info for the access token so we can compare the device ID. | # Look up the user info for the access token so we can compare the device ID. | ||||
lookup_result: TokenLookupResult = self.get_success( | |||||
lookup_result = self.get_success( | |||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert lookup_result is not None | |||||
# Get the user's devices and check it has the correct device ID. | # Get the user's devices and check it has the correct device ID. | ||||
channel = self.make_request("GET", "/pushers", access_token=access_token) | channel = self.make_request("GET", "/pushers", access_token=access_token) | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Any, List, Optional | |||||
from typing import Any, List, Optional, Sequence | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase): | |||||
) | ) | ||||
# this is the point in the DAG where we make a fork | # this is the point in the DAG where we make a fork | ||||
fork_point: List[str] = self.get_success( | |||||
fork_point: Sequence[str] = self.get_success( | |||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
@@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase): | |||||
pl_event = self.get_success( | pl_event = self.get_success( | ||||
inject_event( | inject_event( | ||||
self.hs, | self.hs, | ||||
prev_event_ids=prev_events, | |||||
prev_event_ids=list(prev_events), | |||||
type=EventTypes.PowerLevels, | type=EventTypes.PowerLevels, | ||||
state_key="", | state_key="", | ||||
sender=self.user_id, | sender=self.user_id, | ||||
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase): | |||||
) | ) | ||||
# this is the point in the DAG where we make a fork | # this is the point in the DAG where we make a fork | ||||
fork_point: List[str] = self.get_success( | |||||
fork_point: Sequence[str] = self.get_success( | |||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) | ||||
) | ) | ||||
@@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase): | |||||
e = self.get_success( | e = self.get_success( | ||||
inject_event( | inject_event( | ||||
self.hs, | self.hs, | ||||
prev_event_ids=prev_events, | |||||
prev_event_ids=list(prev_events), | |||||
type=EventTypes.PowerLevels, | type=EventTypes.PowerLevels, | ||||
state_key="", | state_key="", | ||||
sender=self.user_id, | sender=self.user_id, | ||||
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase): | |||||
room_id = self.helper.create_room_as("@bob:test") | room_id = self.helper.create_room_as("@bob:test") | ||||
# Mark the room as partial-stated. | # Mark the room as partial-stated. | ||||
self.get_success( | self.get_success( | ||||
self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1") | |||||
self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1") | |||||
) | ) | ||||
worker = self.make_worker_hs("synapse.app.generic_worker") | worker = self.make_worker_hs("synapse.app.generic_worker") | ||||
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
from unittest.mock import Mock | from unittest.mock import Mock | ||||
from synapse.handlers.typing import RoomMember | |||||
from synapse.handlers.typing import RoomMember, TypingWriterHandler | |||||
from synapse.replication.tcp.streams import TypingStream | from synapse.replication.tcp.streams import TypingStream | ||||
from synapse.util.caches.stream_change_cache import StreamChangeCache | from synapse.util.caches.stream_change_cache import StreamChangeCache | ||||
@@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase): | |||||
def test_typing(self) -> None: | def test_typing(self) -> None: | ||||
typing = self.hs.get_typing_handler() | typing = self.hs.get_typing_handler() | ||||
assert isinstance(typing, TypingWriterHandler) | |||||
self.reconnect() | self.reconnect() | ||||
@@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase): | |||||
sends the proper position and RDATA). | sends the proper position and RDATA). | ||||
""" | """ | ||||
typing = self.hs.get_typing_handler() | typing = self.hs.get_typing_handler() | ||||
assert isinstance(typing, TypingWriterHandler) | |||||
self.reconnect() | self.reconnect() | ||||
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase): | |||||
# ... updating the cache ID gen on the master still shouldn't cause the | # ... updating the cache ID gen on the master still shouldn't cause the | ||||
# deferred to wake up. | # deferred to wake up. | ||||
assert store._cache_id_gen is not None | |||||
ctx = store._cache_id_gen.get_next() | ctx = store._cache_id_gen.get_next() | ||||
self.get_success(ctx.__aenter__()) | self.get_success(ctx.__aenter__()) | ||||
self.get_success(ctx.__aexit__(None, None, None)) | self.get_success(ctx.__aexit__(None, None, None)) | ||||
@@ -16,6 +16,7 @@ from unittest.mock import Mock | |||||
from synapse.api.constants import EventTypes, Membership | from synapse.api.constants import EventTypes, Membership | ||||
from synapse.events.builder import EventBuilderFactory | from synapse.events.builder import EventBuilderFactory | ||||
from synapse.handlers.typing import TypingWriterHandler | |||||
from synapse.rest.admin import register_servlets_for_client_rest_resource | from synapse.rest.admin import register_servlets_for_client_rest_resource | ||||
from synapse.rest.client import login, room | from synapse.rest.client import login, room | ||||
from synapse.types import UserID, create_requester | from synapse.types import UserID, create_requester | ||||
@@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): | |||||
token = self.login("user3", "pass") | token = self.login("user3", "pass") | ||||
typing_handler = self.hs.get_typing_handler() | typing_handler = self.hs.get_typing_handler() | ||||
assert isinstance(typing_handler, TypingWriterHandler) | |||||
sent_on_1 = False | sent_on_1 = False | ||||
sent_on_2 = False | sent_on_2 = False | ||||
@@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): | |||||
user_dict = self.get_success( | user_dict = self.get_success( | ||||
self.hs.get_datastores().main.get_user_by_access_token(access_token) | self.hs.get_datastores().main.get_user_by_access_token(access_token) | ||||
) | ) | ||||
assert user_dict is not None | |||||
token_id = user_dict.token_id | token_id = user_dict.token_id | ||||
self.get_success( | self.get_success( | ||||
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): | |||||
other_user_tok = self.login("user", "pass") | other_user_tok = self.login("user", "pass") | ||||
event_builder_factory = self.hs.get_event_builder_factory() | event_builder_factory = self.hs.get_event_builder_factory() | ||||
event_creation_handler = self.hs.get_event_creation_handler() | event_creation_handler = self.hs.get_event_creation_handler() | ||||
storage_controllers = self.hs.get_storage_controllers() | |||||
persistence = self.hs.get_storage_controllers().persistence | |||||
assert persistence is not None | |||||
# Create two rooms, one with a local user only and one with both a local | # Create two rooms, one with a local user only and one with both a local | ||||
# and remote user. | # and remote user. | ||||
@@ -2940,7 +2941,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): | |||||
context = self.get_success(unpersisted_context.persist(event)) | context = self.get_success(unpersisted_context.persist(event)) | ||||
self.get_success(storage_controllers.persistence.persist_event(event, context)) | |||||
self.get_success(persistence.persist_event(event, context)) | |||||
# Now get rooms | # Now get rooms | ||||
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" | url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms" | ||||
@@ -11,6 +11,8 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Optional | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
import synapse.rest.admin | import synapse.rest.admin | ||||
@@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): | |||||
self.register_user("admin", "pass", admin=True) | self.register_user("admin", "pass", admin=True) | ||||
self.admin_user_tok = self.login("admin", "pass") | self.admin_user_tok = self.login("admin", "pass") | ||||
async def check_username(username: str) -> bool: | |||||
if username == "allowed": | |||||
return True | |||||
async def check_username( | |||||
localpart: str, | |||||
guest_access_token: Optional[str] = None, | |||||
assigned_user_id: Optional[str] = None, | |||||
inhibit_user_in_use_error: bool = False, | |||||
) -> None: | |||||
if localpart == "allowed": | |||||
return | |||||
raise SynapseError( | raise SynapseError( | ||||
400, | 400, | ||||
"User ID already taken.", | "User ID already taken.", | ||||
@@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
handler = self.hs.get_registration_handler() | handler = self.hs.get_registration_handler() | ||||
handler.check_username = check_username | |||||
handler.check_username = check_username # type: ignore[assignment] | |||||
def test_username_available(self) -> None: | def test_username_available(self) -> None: | ||||
""" | """ | ||||
@@ -1193,7 +1193,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): | |||||
return {} | return {} | ||||
# Register a mock that will return the expected result depending on the remote. | # Register a mock that will return the expected result depending on the remote. | ||||
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) | |||||
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] | |||||
# Check that we've got the correct response from the client-side endpoint. | # Check that we've got the correct response from the client-side endpoint. | ||||
self._test_status( | self._test_status( | ||||
@@ -63,14 +63,14 @@ class FilterTestCase(unittest.HomeserverTestCase): | |||||
def test_add_filter_non_local_user(self) -> None: | def test_add_filter_non_local_user(self) -> None: | ||||
_is_mine = self.hs.is_mine | _is_mine = self.hs.is_mine | ||||
self.hs.is_mine = lambda target_user: False | |||||
self.hs.is_mine = lambda target_user: False # type: ignore[assignment] | |||||
channel = self.make_request( | channel = self.make_request( | ||||
"POST", | "POST", | ||||
"/_matrix/client/r0/user/%s/filter" % (self.user_id), | "/_matrix/client/r0/user/%s/filter" % (self.user_id), | ||||
self.EXAMPLE_FILTER_JSON, | self.EXAMPLE_FILTER_JSON, | ||||
) | ) | ||||
self.hs.is_mine = _is_mine | |||||
self.hs.is_mine = _is_mine # type: ignore[assignment] | |||||
self.assertEqual(channel.code, 403) | self.assertEqual(channel.code, 403) | ||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) | self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) | ||||
@@ -36,14 +36,14 @@ class PresenceTestCase(unittest.HomeserverTestCase): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
presence_handler = Mock(spec=PresenceHandler) | |||||
presence_handler.set_state.return_value = make_awaitable(None) | |||||
self.presence_handler = Mock(spec=PresenceHandler) | |||||
self.presence_handler.set_state.return_value = make_awaitable(None) | |||||
hs = self.setup_test_homeserver( | hs = self.setup_test_homeserver( | ||||
"red", | "red", | ||||
federation_http_client=None, | federation_http_client=None, | ||||
federation_client=Mock(), | federation_client=Mock(), | ||||
presence_handler=presence_handler, | |||||
presence_handler=self.presence_handler, | |||||
) | ) | ||||
return hs | return hs | ||||
@@ -61,7 +61,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
self.assertEqual(channel.code, HTTPStatus.OK) | self.assertEqual(channel.code, HTTPStatus.OK) | ||||
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) | |||||
self.assertEqual(self.presence_handler.set_state.call_count, 1) | |||||
@unittest.override_config({"use_presence": False}) | @unittest.override_config({"use_presence": False}) | ||||
def test_put_presence_disabled(self) -> None: | def test_put_presence_disabled(self) -> None: | ||||
@@ -76,4 +76,4 @@ class PresenceTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
self.assertEqual(channel.code, HTTPStatus.OK) | self.assertEqual(channel.code, HTTPStatus.OK) | ||||
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) | |||||
self.assertEqual(self.presence_handler.set_state.call_count, 0) |
@@ -151,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): | |||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | ||||
def test_POST_guest_registration(self) -> None: | def test_POST_guest_registration(self) -> None: | ||||
self.hs.config.key.macaroon_secret_key = "test" | |||||
self.hs.config.key.macaroon_secret_key = b"test" | |||||
self.hs.config.registration.allow_guest_access = True | self.hs.config.registration.allow_guest_access = True | ||||
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") | channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") | ||||
@@ -1166,12 +1166,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): | |||||
""" | """ | ||||
user_id = self.register_user("kermit_delta", "user") | user_id = self.register_user("kermit_delta", "user") | ||||
self.hs.config.account_validity.startup_job_max_delta = self.max_delta | |||||
self.hs.config.account_validity.account_validity_startup_job_max_delta = ( | |||||
self.max_delta | |||||
) | |||||
now_ms = self.hs.get_clock().time_msec() | now_ms = self.hs.get_clock().time_msec() | ||||
self.get_success(self.store._set_expiration_date_when_missing()) | self.get_success(self.store._set_expiration_date_when_missing()) | ||||
res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) | res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) | ||||
assert res is not None | |||||
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) | self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) | ||||
self.assertLessEqual(res, now_ms + self.validity_period) | self.assertLessEqual(res, now_ms + self.validity_period) | ||||
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): | |||||
# Send a first event, which should be filtered out at the end of the test. | # Send a first event, which should be filtered out at the end of the test. | ||||
resp = self.helper.send(room_id=room_id, body="1", tok=self.token) | resp = self.helper.send(room_id=room_id, body="1", tok=self.token) | ||||
first_event_id = resp.get("event_id") | first_event_id = resp.get("event_id") | ||||
assert isinstance(first_event_id, str) | |||||
# Advance the time by 2 days. We're using the default retention policy, therefore | # Advance the time by 2 days. We're using the default retention policy, therefore | ||||
# after this the first event will still be valid. | # after this the first event will still be valid. | ||||
@@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): | |||||
# Send another event, which shouldn't get filtered out. | # Send another event, which shouldn't get filtered out. | ||||
resp = self.helper.send(room_id=room_id, body="2", tok=self.token) | resp = self.helper.send(room_id=room_id, body="2", tok=self.token) | ||||
valid_event_id = resp.get("event_id") | valid_event_id = resp.get("event_id") | ||||
assert isinstance(valid_event_id, str) | |||||
# Advance the time by another 2 days. After this, the first event should be | # Advance the time by another 2 days. After this, the first event should be | ||||
# outdated but not the second one. | # outdated but not the second one. | ||||
@@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): | |||||
# Check that we can still access state events that were sent before the event that | # Check that we can still access state events that were sent before the event that | ||||
# has been purged. | # has been purged. | ||||
self.get_event(room_id, create_event.event_id) | |||||
self.get_event(room_id, bool(create_event)) | |||||
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: | def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict: | ||||
event = self.get_success(self.store.get_event(event_id, allow_none=True)) | event = self.get_success(self.store.get_event(event_id, allow_none=True)) | ||||
@@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): | |||||
self.assertIsNone(event) | self.assertIsNone(event) | ||||
return {} | return {} | ||||
self.assertIsNotNone(event) | |||||
assert event is not None | |||||
time_now = self.clock.time_msec() | time_now = self.clock.time_msec() | ||||
serialized = self.serializer.serialize_event(event, time_now) | serialized = self.serializer.serialize_event(event, time_now) | ||||
@@ -3382,8 +3382,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): | |||||
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we | # a remote IS. We keep the mock for make_and_store_3pid_invite around so we | ||||
# can check its call_count later on during the test. | # can check its call_count later on during the test. | ||||
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) | make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) | ||||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock | |||||
self.hs.get_identity_handler().lookup_3pid = Mock( | |||||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] | |||||
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None), | return_value=make_awaitable(None), | ||||
) | ) | ||||
@@ -3443,8 +3443,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): | |||||
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we | # a remote IS. We keep the mock for make_and_store_3pid_invite around so we | ||||
# can check its call_count later on during the test. | # can check its call_count later on during the test. | ||||
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) | make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) | ||||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock | |||||
self.hs.get_identity_handler().lookup_3pid = Mock( | |||||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] | |||||
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None), | return_value=make_awaitable(None), | ||||
) | ) | ||||
@@ -3563,8 +3563,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
event.internal_metadata.outlier = True | event.internal_metadata.outlier = True | ||||
persistence = self._storage_controllers.persistence | |||||
assert persistence is not None | |||||
self.get_success( | self.get_success( | ||||
self._storage_controllers.persistence.persist_event( | |||||
persistence.persist_event( | |||||
event, EventContext.for_outlier(self._storage_controllers) | event, EventContext.for_outlier(self._storage_controllers) | ||||
) | ) | ||||
) | ) | ||||
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase): | |||||
def test_invite_3pid(self) -> None: | def test_invite_3pid(self) -> None: | ||||
"""Ensure that a 3PID invite does not attempt to contact the identity server.""" | """Ensure that a 3PID invite does not attempt to contact the identity server.""" | ||||
identity_handler = self.hs.get_identity_handler() | identity_handler = self.hs.get_identity_handler() | ||||
identity_handler.lookup_3pid = Mock( | |||||
identity_handler.lookup_3pid = Mock( # type: ignore[assignment] | |||||
side_effect=AssertionError("This should not get called") | side_effect=AssertionError("This should not get called") | ||||
) | ) | ||||
@@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase): | |||||
event_source.get_new_events( | event_source.get_new_events( | ||||
user=UserID.from_string(self.other_user_id), | user=UserID.from_string(self.other_user_id), | ||||
from_key=0, | from_key=0, | ||||
limit=None, | |||||
limit=10, | |||||
room_ids=[room_id], | room_ids=[room_id], | ||||
is_guest=False, | is_guest=False, | ||||
) | ) | ||||
@@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase): | |||||
self.banned_user_id, | self.banned_user_id, | ||||
) | ) | ||||
) | ) | ||||
assert event is not None | |||||
self.assertEqual( | self.assertEqual( | ||||
event.content, {"membership": "join", "displayname": original_display_name} | event.content, {"membership": "join", "displayname": original_display_name} | ||||
) | ) | ||||
@@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase): | |||||
self.banned_user_id, | self.banned_user_id, | ||||
) | ) | ||||
) | ) | ||||
assert event is not None | |||||
self.assertEqual( | self.assertEqual( | ||||
event.content, {"membership": "join", "displayname": original_display_name} | event.content, {"membership": "join", "displayname": original_display_name} | ||||
) | ) |
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): | |||||
self.room_id, EventTypes.Tombstone, "" | self.room_id, EventTypes.Tombstone, "" | ||||
) | ) | ||||
) | ) | ||||
self.assertIsNotNone(tombstone_event) | |||||
assert tombstone_event is not None | |||||
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) | self.assertEqual(new_room_id, tombstone_event.content["replacement_room"]) | ||||
# Check that the new room exists. | # Check that the new room exists. | ||||
@@ -24,6 +24,7 @@ from synapse.server import HomeServer | |||||
from synapse.server_notices.resource_limits_server_notices import ( | from synapse.server_notices.resource_limits_server_notices import ( | ||||
ResourceLimitsServerNotices, | ResourceLimitsServerNotices, | ||||
) | ) | ||||
from synapse.server_notices.server_notices_sender import ServerNoticesSender | |||||
from synapse.types import JsonDict | from synapse.types import JsonDict | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
@@ -58,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
return config | return config | ||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
self.server_notices_sender = self.hs.get_server_notices_sender() | |||||
server_notices_sender = self.hs.get_server_notices_sender() | |||||
assert isinstance(server_notices_sender, ServerNoticesSender) | |||||
# relying on [1] is far from ideal, but the only case where | # relying on [1] is far from ideal, but the only case where | ||||
# ResourceLimitsServerNotices class needs to be isolated is this test, | # ResourceLimitsServerNotices class needs to be isolated is this test, | ||||
# general code should never have a reason to do so ... | # general code should never have a reason to do so ... | ||||
self._rlsn = self.server_notices_sender._server_notices[1] | |||||
if not isinstance(self._rlsn, ResourceLimitsServerNotices): | |||||
raise Exception("Failed to find reference to ResourceLimitsServerNotices") | |||||
rlsn = list(server_notices_sender._server_notices)[1] | |||||
assert isinstance(rlsn, ResourceLimitsServerNotices) | |||||
self._rlsn = rlsn | |||||
self._rlsn._store.user_last_seen_monthly_active = Mock( | self._rlsn._store.user_last_seen_monthly_active = Mock( | ||||
return_value=make_awaitable(1000) | return_value=make_awaitable(1000) | ||||
@@ -101,25 +103,29 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: | def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: | ||||
"""Test when user has blocked notice, but should have it removed""" | """Test when user has blocked notice, but should have it removed""" | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None) | return_value=make_awaitable(None) | ||||
) | ) | ||||
mock_event = Mock( | mock_event = Mock( | ||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | ||||
) | ) | ||||
self._rlsn._store.get_events = Mock( | |||||
self._rlsn._store.get_events = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable({"123": mock_event}) | return_value=make_awaitable({"123": mock_event}) | ||||
) | ) | ||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | ||||
# Would be better to check the content, but once == remove blocking event | # Would be better to check the content, but once == remove blocking event | ||||
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() | |||||
maybe_get_notice_room_for_user = ( | |||||
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user | |||||
) | |||||
assert isinstance(maybe_get_notice_room_for_user, Mock) | |||||
maybe_get_notice_room_for_user.assert_called_once() | |||||
self._send_notice.assert_called_once() | self._send_notice.assert_called_once() | ||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None: | def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None: | ||||
""" | """ | ||||
Test when user has blocked notice, but notice ought to be there (NOOP) | Test when user has blocked notice, but notice ought to be there (NOOP) | ||||
""" | """ | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None), | return_value=make_awaitable(None), | ||||
side_effect=ResourceLimitError(403, "foo"), | side_effect=ResourceLimitError(403, "foo"), | ||||
) | ) | ||||
@@ -127,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
mock_event = Mock( | mock_event = Mock( | ||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | ||||
) | ) | ||||
self._rlsn._store.get_events = Mock( | |||||
self._rlsn._store.get_events = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable({"123": mock_event}) | return_value=make_awaitable({"123": mock_event}) | ||||
) | ) | ||||
@@ -139,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
""" | """ | ||||
Test when user does not have blocked notice, but should have one | Test when user does not have blocked notice, but should have one | ||||
""" | """ | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None), | return_value=make_awaitable(None), | ||||
side_effect=ResourceLimitError(403, "foo"), | side_effect=ResourceLimitError(403, "foo"), | ||||
) | ) | ||||
@@ -152,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
""" | """ | ||||
Test when user does not have blocked notice, nor should they (NOOP) | Test when user does not have blocked notice, nor should they (NOOP) | ||||
""" | """ | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None) | return_value=make_awaitable(None) | ||||
) | ) | ||||
@@ -165,7 +171,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
Test when user is not part of the MAU cohort - this should not ever | Test when user is not part of the MAU cohort - this should not ever | ||||
happen - but ... | happen - but ... | ||||
""" | """ | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None) | return_value=make_awaitable(None) | ||||
) | ) | ||||
self._rlsn._store.user_last_seen_monthly_active = Mock( | self._rlsn._store.user_last_seen_monthly_active = Mock( | ||||
@@ -183,7 +189,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
Test that when server is over MAU limit and alerting is suppressed, then | Test that when server is over MAU limit and alerting is suppressed, then | ||||
an alert message is not sent into the room | an alert message is not sent into the room | ||||
""" | """ | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None), | return_value=make_awaitable(None), | ||||
side_effect=ResourceLimitError( | side_effect=ResourceLimitError( | ||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER | 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER | ||||
@@ -198,7 +204,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
""" | """ | ||||
Test that when a server is disabled, that MAU limit alerting is ignored. | Test that when a server is disabled, that MAU limit alerting is ignored. | ||||
""" | """ | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None), | return_value=make_awaitable(None), | ||||
side_effect=ResourceLimitError( | side_effect=ResourceLimitError( | ||||
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED | 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED | ||||
@@ -217,21 +223,21 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||||
When the room is already in a blocked state, test that when alerting | When the room is already in a blocked state, test that when alerting | ||||
is suppressed that the room is returned to an unblocked state. | is suppressed that the room is returned to an unblocked state. | ||||
""" | """ | ||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( | |||||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable(None), | return_value=make_awaitable(None), | ||||
side_effect=ResourceLimitError( | side_effect=ResourceLimitError( | ||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER | 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER | ||||
), | ), | ||||
) | ) | ||||
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( | |||||
self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable((True, [])) | return_value=make_awaitable((True, [])) | ||||
) | ) | ||||
mock_event = Mock( | mock_event = Mock( | ||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | ||||
) | ) | ||||
self._rlsn._store.get_events = Mock( | |||||
self._rlsn._store.get_events = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable({"123": mock_event}) | return_value=make_awaitable({"123": mock_event}) | ||||
) | ) | ||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | ||||
@@ -262,16 +268,18 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): | |||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
self.store = self.hs.get_datastores().main | self.store = self.hs.get_datastores().main | ||||
self.server_notices_sender = self.hs.get_server_notices_sender() | |||||
self.server_notices_manager = self.hs.get_server_notices_manager() | self.server_notices_manager = self.hs.get_server_notices_manager() | ||||
self.event_source = self.hs.get_event_sources() | self.event_source = self.hs.get_event_sources() | ||||
server_notices_sender = self.hs.get_server_notices_sender() | |||||
assert isinstance(server_notices_sender, ServerNoticesSender) | |||||
# relying on [1] is far from ideal, but the only case where | # relying on [1] is far from ideal, but the only case where | ||||
# ResourceLimitsServerNotices class needs to be isolated is this test, | # ResourceLimitsServerNotices class needs to be isolated is this test, | ||||
# general code should never have a reason to do so ... | # general code should never have a reason to do so ... | ||||
self._rlsn = self.server_notices_sender._server_notices[1] | |||||
if not isinstance(self._rlsn, ResourceLimitsServerNotices): | |||||
raise Exception("Failed to find reference to ResourceLimitsServerNotices") | |||||
rlsn = list(server_notices_sender._server_notices)[1] | |||||
assert isinstance(rlsn, ResourceLimitsServerNotices) | |||||
self._rlsn = rlsn | |||||
self.user_id = "@user_id:test" | self.user_id = "@user_id:test" | ||||
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): | |||||
# Persist the event which should invalidate or prefill the | # Persist the event which should invalidate or prefill the | ||||
# `have_seen_event` cache so we don't return stale values. | # `have_seen_event` cache so we don't return stale values. | ||||
persistence = self.hs.get_storage_controllers().persistence | persistence = self.hs.get_storage_controllers().persistence | ||||
assert persistence is not None | |||||
self.get_success( | self.get_success( | ||||
persistence.persist_event( | persistence.persist_event( | ||||
event, | event, | ||||
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||||
""" | """ | ||||
persist_events_store = self.hs.get_datastores().persist_events | persist_events_store = self.hs.get_datastores().persist_events | ||||
assert persist_events_store is not None | |||||
for e in events: | for e in events: | ||||
e.internal_metadata.stream_ordering = self._next_stream_ordering | e.internal_metadata.stream_ordering = self._next_stream_ordering | ||||
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase): | |||||
def _persist(txn: LoggingTransaction) -> None: | def _persist(txn: LoggingTransaction) -> None: | ||||
# We need to persist the events to the events and state_events | # We need to persist the events to the events and state_events | ||||
# tables. | # tables. | ||||
assert persist_events_store is not None | |||||
persist_events_store._store_event_txn( | persist_events_store._store_event_txn( | ||||
txn, | txn, | ||||
[(e, EventContext(self.hs.get_storage_controllers())) for e in events], | [(e, EventContext(self.hs.get_storage_controllers())) for e in events], | ||||
@@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||||
self.requester, events_and_context=[(event, context)] | self.requester, events_and_context=[(event, context)] | ||||
) | ) | ||||
) | ) | ||||
state1 = set(self.get_success(context.get_current_state_ids()).values()) | |||||
state_ids1 = self.get_success(context.get_current_state_ids()) | |||||
assert state_ids1 is not None | |||||
state1 = set(state_ids1.values()) | |||||
event, context = self.get_success( | event, context = self.get_success( | ||||
event_handler.create_event( | event_handler.create_event( | ||||
@@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): | |||||
self.requester, events_and_context=[(event, context)] | self.requester, events_and_context=[(event, context)] | ||||
) | ) | ||||
) | ) | ||||
state2 = set(self.get_success(context.get_current_state_ids()).values()) | |||||
state_ids2 = self.get_success(context.get_current_state_ids()) | |||||
assert state_ids2 is not None | |||||
state2 = set(state_ids2.values()) | |||||
# Delete the chain cover info. | # Delete the chain cover info. | ||||
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
self.store = hs.get_datastores().main | self.store = hs.get_datastores().main | ||||
persist_events = hs.get_datastores().persist_events | |||||
assert persist_events is not None | |||||
self.persist_events = persist_events | |||||
def test_get_prev_events_for_room(self) -> None: | def test_get_prev_events_for_room(self) -> None: | ||||
room_id = "@ROOM:local" | room_id = "@ROOM:local" | ||||
@@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
}, | }, | ||||
) | ) | ||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |||||
self.persist_events._persist_event_auth_chain_txn( | |||||
txn, | txn, | ||||
[ | [ | ||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) | cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) | ||||
@@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
) | ) | ||||
# Insert all events apart from 'B' | # Insert all events apart from 'B' | ||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |||||
self.persist_events._persist_event_auth_chain_txn( | |||||
txn, | txn, | ||||
[ | [ | ||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) | cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) | ||||
@@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
updatevalues={"has_auth_chain_index": False}, | updatevalues={"has_auth_chain_index": False}, | ||||
) | ) | ||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn( | |||||
self.persist_events._persist_event_auth_chain_txn( | |||||
txn, | txn, | ||||
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], | [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))], | ||||
) | ) | ||||
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase): | |||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | ||||
) -> None: | ) -> None: | ||||
self.state = self.hs.get_state_handler() | self.state = self.hs.get_state_handler() | ||||
self._persistence = self.hs.get_storage_controllers().persistence | |||||
persistence = self.hs.get_storage_controllers().persistence | |||||
assert persistence is not None | |||||
self._persistence = persistence | |||||
self._state_storage_controller = self.hs.get_storage_controllers().state | self._state_storage_controller = self.hs.get_storage_controllers().state | ||||
self.store = self.hs.get_datastores().main | self.store = self.hs.get_datastores().main | ||||
@@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): | |||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer | ||||
) -> None: | ) -> None: | ||||
self.state = self.hs.get_state_handler() | self.state = self.hs.get_state_handler() | ||||
self._persistence = self.hs.get_storage_controllers().persistence | |||||
persistence = self.hs.get_storage_controllers().persistence | |||||
assert persistence is not None | |||||
self._persistence = persistence | |||||
self.store = self.hs.get_datastores().main | self.store = self.hs.get_datastores().main | ||||
def test_remote_user_rooms_cache_invalidated(self) -> None: | def test_remote_user_rooms_cache_invalidated(self) -> None: | ||||
@@ -16,8 +16,6 @@ import signedjson.key | |||||
import signedjson.types | import signedjson.types | ||||
import unpaddedbase64 | import unpaddedbase64 | ||||
from twisted.internet.defer import Deferred | |||||
from synapse.storage.keys import FetchKeyResult | from synapse.storage.keys import FetchKeyResult | ||||
import tests.unittest | import tests.unittest | ||||
@@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
key_id_1 = "ed25519:key1" | key_id_1 = "ed25519:key1" | ||||
key_id_2 = "ed25519:KEY_ID_2" | key_id_2 = "ed25519:KEY_ID_2" | ||||
d = store.store_server_verify_keys( | |||||
"from_server", | |||||
10, | |||||
[ | |||||
("server1", key_id_1, FetchKeyResult(KEY_1, 100)), | |||||
("server1", key_id_2, FetchKeyResult(KEY_2, 200)), | |||||
], | |||||
self.get_success( | |||||
store.store_server_verify_keys( | |||||
"from_server", | |||||
10, | |||||
[ | |||||
("server1", key_id_1, FetchKeyResult(KEY_1, 100)), | |||||
("server1", key_id_2, FetchKeyResult(KEY_2, 200)), | |||||
], | |||||
) | |||||
) | ) | ||||
self.get_success(d) | |||||
d = store.get_server_verify_keys( | |||||
[("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")] | |||||
res = self.get_success( | |||||
store.get_server_verify_keys( | |||||
[ | |||||
("server1", key_id_1), | |||||
("server1", key_id_2), | |||||
("server1", "ed25519:key3"), | |||||
] | |||||
) | |||||
) | ) | ||||
res = self.get_success(d) | |||||
self.assertEqual(len(res.keys()), 3) | self.assertEqual(len(res.keys()), 3) | ||||
res1 = res[("server1", key_id_1)] | res1 = res[("server1", key_id_1)] | ||||
@@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
key_id_1 = "ed25519:key1" | key_id_1 = "ed25519:key1" | ||||
key_id_2 = "ed25519:key2" | key_id_2 = "ed25519:key2" | ||||
d = store.store_server_verify_keys( | |||||
"from_server", | |||||
0, | |||||
[ | |||||
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), | |||||
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), | |||||
], | |||||
self.get_success( | |||||
store.store_server_verify_keys( | |||||
"from_server", | |||||
0, | |||||
[ | |||||
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), | |||||
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), | |||||
], | |||||
) | |||||
) | ) | ||||
self.get_success(d) | |||||
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||||
res = self.get_success(d) | |||||
res = self.get_success( | |||||
store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||||
) | |||||
self.assertEqual(len(res.keys()), 2) | self.assertEqual(len(res.keys()), 2) | ||||
res1 = res[("srv1", key_id_1)] | res1 = res[("srv1", key_id_1)] | ||||
@@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
self.assertEqual(res2.valid_until_ts, 200) | self.assertEqual(res2.valid_until_ts, 200) | ||||
# we should be able to look up the same thing again without a db hit | # we should be able to look up the same thing again without a db hit | ||||
res = store.get_server_verify_keys([("srv1", key_id_1)]) | |||||
if isinstance(res, Deferred): | |||||
res = self.successResultOf(res) | |||||
res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)])) | |||||
self.assertEqual(len(res.keys()), 1) | self.assertEqual(len(res.keys()), 1) | ||||
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) | self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) | ||||
@@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
) | ) | ||||
self.get_success(d) | self.get_success(d) | ||||
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||||
res = self.get_success(d) | |||||
res = self.get_success( | |||||
store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||||
) | |||||
self.assertEqual(len(res.keys()), 2) | self.assertEqual(len(res.keys()), 2) | ||||
res1 = res[("srv1", key_id_1)] | res1 = res[("srv1", key_id_1)] | ||||
@@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase): | |||||
self.room_id, "m.room.create", "" | self.room_id, "m.room.create", "" | ||||
) | ) | ||||
) | ) | ||||
self.assertIsNotNone(create_event) | |||||
assert create_event is not None | |||||
# Purge everything before this topological token | # Purge everything before this topological token | ||||
self.get_success( | self.get_success( | ||||
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase): | |||||
self.store = homeserver.get_datastores().main | self.store = homeserver.get_datastores().main | ||||
self.room_creator = homeserver.get_room_creation_handler() | self.room_creator = homeserver.get_room_creation_handler() | ||||
self.persist_event_storage_controller = ( | |||||
self.hs.get_storage_controllers().persistence | |||||
) | |||||
persist_event_storage_controller = self.hs.get_storage_controllers().persistence | |||||
assert persist_event_storage_controller is not None | |||||
self.persist_event_storage_controller = persist_event_storage_controller | |||||
# Create a test user | # Create a test user | ||||
self.ourUser = UserID.from_string(OUR_USER_ID) | self.ourUser = UserID.from_string(OUR_USER_ID) | ||||
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase): | |||||
"content": {"msgtype": "m.text", "body": 2}, | "content": {"msgtype": "m.text", "body": 2}, | ||||
"room_id": room_id, | "room_id": room_id, | ||||
"sender": user_id, | "sender": user_id, | ||||
"depth": prev_event.depth + 1, | |||||
"prev_events": prev_event_ids, | "prev_events": prev_event_ids, | ||||
"origin_server_ts": self.clock.time_msec(), | "origin_server_ts": self.clock.time_msec(), | ||||
} | } | ||||
@@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase): | |||||
prev_state_map, | prev_state_map, | ||||
for_verification=False, | for_verification=False, | ||||
), | ), | ||||
depth=event_dict["depth"], | |||||
depth=prev_event.depth + 1, | |||||
) | ) | ||||
) | ) | ||||
@@ -16,7 +16,7 @@ from typing import List | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
from synapse.api.constants import EventTypes, RelationTypes | |||||
from synapse.api.constants import Direction, EventTypes, RelationTypes | |||||
from synapse.api.filtering import Filter | from synapse.api.filtering import Filter | ||||
from synapse.rest import admin | from synapse.rest import admin | ||||
from synapse.rest.client import login, room | from synapse.rest.client import login, room | ||||
@@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase): | |||||
room_id=self.room_id, | room_id=self.room_id, | ||||
from_key=self.from_token.room_key, | from_key=self.from_token.room_key, | ||||
to_key=None, | to_key=None, | ||||
direction="f", | |||||
direction=Direction.FORWARDS, | |||||
limit=10, | limit=10, | ||||
event_filter=Filter(self.hs, filter), | event_filter=Filter(self.hs, filter), | ||||
) | ) | ||||
@@ -14,6 +14,7 @@ | |||||
from unittest.mock import MagicMock, patch | from unittest.mock import MagicMock, patch | ||||
from synapse.storage.database import make_conn | from synapse.storage.database import make_conn | ||||
from synapse.storage.engines import PostgresEngine | |||||
from synapse.storage.engines._base import IncorrectDatabaseSetup | from synapse.storage.engines._base import IncorrectDatabaseSetup | ||||
from tests.unittest import HomeserverTestCase | from tests.unittest import HomeserverTestCase | ||||
@@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase): | |||||
def test_safe_locale(self) -> None: | def test_safe_locale(self) -> None: | ||||
database = self.hs.get_datastores().databases[0] | database = self.hs.get_datastores().databases[0] | ||||
assert isinstance(database.engine, PostgresEngine) | |||||
db_conn = make_conn(database._database_config, database.engine, "test_unsafe") | db_conn = make_conn(database._database_config, database.engine, "test_unsafe") | ||||
with db_conn.cursor() as txn: | with db_conn.cursor() as txn: | ||||
@@ -12,17 +12,17 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Optional, Union | |||||
from typing import Collection, List, Optional, Union | |||||
from unittest.mock import Mock | from unittest.mock import Mock | ||||
from twisted.internet.defer import succeed | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
from synapse.api.errors import FederationError | from synapse.api.errors import FederationError | ||||
from synapse.api.room_versions import RoomVersions | |||||
from synapse.api.room_versions import RoomVersion, RoomVersions | |||||
from synapse.events import EventBase, make_event_from_dict | from synapse.events import EventBase, make_event_from_dict | ||||
from synapse.events.snapshot import EventContext | from synapse.events.snapshot import EventContext | ||||
from synapse.federation.federation_base import event_from_pdu_json | from synapse.federation.federation_base import event_from_pdu_json | ||||
from synapse.handlers.device import DeviceListUpdater | |||||
from synapse.http.types import QueryParams | from synapse.http.types import QueryParams | ||||
from synapse.logging.context import LoggingContext | from synapse.logging.context import LoggingContext | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
@@ -81,11 +81,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
) -> None: | ) -> None: | ||||
pass | pass | ||||
federation_event_handler._check_event_auth = _check_event_auth | |||||
federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment] | |||||
self.client = self.hs.get_federation_client() | self.client = self.hs.get_federation_client() | ||||
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( | |||||
lambda dest, pdus, **k: succeed(pdus) | |||||
) | |||||
async def _check_sigs_and_hash_for_pulled_events_and_fetch( | |||||
dest: str, pdus: Collection[EventBase], room_version: RoomVersion | |||||
) -> List[EventBase]: | |||||
return list(pdus) | |||||
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] | |||||
# Send the join, it should return None (which is not an error) | # Send the join, it should return None (which is not an error) | ||||
self.assertEqual( | self.assertEqual( | ||||
@@ -187,7 +191,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
# Register the mock on the federation client. | # Register the mock on the federation client. | ||||
federation_client = self.hs.get_federation_client() | federation_client = self.hs.get_federation_client() | ||||
federation_client.query_user_devices = Mock(side_effect=query_user_devices) | |||||
federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment] | |||||
# Register a mock on the store so that the incoming update doesn't fail because | # Register a mock on the store so that the incoming update doesn't fail because | ||||
# we don't share a room with the user. | # we don't share a room with the user. | ||||
@@ -197,6 +201,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
# Manually inject a fake device list update. We need this update to include at | # Manually inject a fake device list update. We need this update to include at | ||||
# least one prev_id so that the user's device list will need to be retried. | # least one prev_id so that the user's device list will need to be retried. | ||||
device_list_updater = self.hs.get_device_handler().device_list_updater | device_list_updater = self.hs.get_device_handler().device_list_updater | ||||
assert isinstance(device_list_updater, DeviceListUpdater) | |||||
self.get_success( | self.get_success( | ||||
device_list_updater.incoming_device_list_update( | device_list_updater.incoming_device_list_update( | ||||
origin=remote_origin, | origin=remote_origin, | ||||
@@ -236,7 +241,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
# Register mock device list retrieval on the federation client. | # Register mock device list retrieval on the federation client. | ||||
federation_client = self.hs.get_federation_client() | federation_client = self.hs.get_federation_client() | ||||
federation_client.query_user_devices = Mock( | |||||
federation_client.query_user_devices = Mock( # type: ignore[assignment] | |||||
return_value=make_awaitable( | return_value=make_awaitable( | ||||
{ | { | ||||
"user_id": remote_user_id, | "user_id": remote_user_id, | ||||
@@ -269,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||||
keys = self.get_success( | keys = self.get_success( | ||||
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), | self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), | ||||
) | ) | ||||
self.assertTrue(remote_user_id in keys) | |||||
self.assertIn(remote_user_id, keys) | |||||
key = keys[remote_user_id] | |||||
assert key is not None | |||||
# Check that the master key is the one returned by the mock. | # Check that the master key is the one returned by the mock. | ||||
master_key = keys[remote_user_id]["master"] | |||||
master_key = key["master"] | |||||
self.assertEqual(len(master_key["keys"]), 1) | self.assertEqual(len(master_key["keys"]), 1) | ||||
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) | self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) | ||||
self.assertTrue(remote_master_key in master_key["keys"].values()) | self.assertTrue(remote_master_key in master_key["keys"].values()) | ||||
# Check that the self-signing key is the one returned by the mock. | # Check that the self-signing key is the one returned by the mock. | ||||
self_signing_key = keys[remote_user_id]["self_signing"] | |||||
self_signing_key = key["self_signing"] | |||||
self.assertEqual(len(self_signing_key["keys"]), 1) | self.assertEqual(len(self_signing_key["keys"]), 1) | ||||
self.assertTrue( | self.assertTrue( | ||||
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), | "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), | ||||
@@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase): | |||||
If time doesn't move, don't error out. | If time doesn't move, don't error out. | ||||
""" | """ | ||||
past_stats = [ | past_stats = [ | ||||
(self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)) | |||||
(int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF)) | |||||
] | ] | ||||
stats: JsonDict = {} | stats: JsonDict = {} | ||||
self.get_success(phone_stats_home(self.hs, stats, past_stats)) | self.get_success(phone_stats_home(self.hs, stats, past_stats)) | ||||
@@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||||
self.event_creation_handler = self.hs.get_event_creation_handler() | self.event_creation_handler = self.hs.get_event_creation_handler() | ||||
self.event_builder_factory = self.hs.get_event_builder_factory() | self.event_builder_factory = self.hs.get_event_builder_factory() | ||||
self._storage_controllers = self.hs.get_storage_controllers() | self._storage_controllers = self.hs.get_storage_controllers() | ||||
assert self._storage_controllers.persistence is not None | |||||
self._persistence = self._storage_controllers.persistence | |||||
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) | self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) | ||||
@@ -179,9 +181,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||||
self.event_creation_handler.create_new_client_event(builder) | self.event_creation_handler.create_new_client_event(builder) | ||||
) | ) | ||||
context = self.get_success(unpersisted_context.persist(event)) | context = self.get_success(unpersisted_context.persist(event)) | ||||
self.get_success( | |||||
self._storage_controllers.persistence.persist_event(event, context) | |||||
) | |||||
self.get_success(self._persistence.persist_event(event, context)) | |||||
return event | return event | ||||
def _inject_room_member( | def _inject_room_member( | ||||
@@ -208,9 +208,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
context = self.get_success(unpersisted_context.persist(event)) | context = self.get_success(unpersisted_context.persist(event)) | ||||
self.get_success( | |||||
self._storage_controllers.persistence.persist_event(event, context) | |||||
) | |||||
self.get_success(self._persistence.persist_event(event, context)) | |||||
return event | return event | ||||
def _inject_message( | def _inject_message( | ||||
@@ -233,9 +231,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
context = self.get_success(unpersisted_context.persist(event)) | context = self.get_success(unpersisted_context.persist(event)) | ||||
self.get_success( | |||||
self._storage_controllers.persistence.persist_event(event, context) | |||||
) | |||||
self.get_success(self._persistence.persist_event(event, context)) | |||||
return event | return event | ||||
def _inject_outlier(self) -> EventBase: | def _inject_outlier(self) -> EventBase: | ||||
@@ -253,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): | |||||
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) | event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) | ||||
event.internal_metadata.outlier = True | event.internal_metadata.outlier = True | ||||
self.get_success( | self.get_success( | ||||
self._storage_controllers.persistence.persist_event( | |||||
self._persistence.persist_event( | |||||
event, EventContext.for_outlier(self._storage_controllers) | event, EventContext.for_outlier(self._storage_controllers) | ||||
) | ) | ||||
) | ) | ||||
@@ -361,7 +361,9 @@ class HomeserverTestCase(TestCase): | |||||
store.db_pool.updates.do_next_background_update(False), by=0.1 | store.db_pool.updates.do_next_background_update(False), by=0.1 | ||||
) | ) | ||||
def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock): | |||||
def make_homeserver( | |||||
self, reactor: ThreadedMemoryReactorClock, clock: Clock | |||||
) -> HomeServer: | |||||
""" | """ | ||||
Make and return a homeserver. | Make and return a homeserver. | ||||
@@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase): | |||||
self.pump() | self.pump() | ||||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) | new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) | ||||
assert new_timings is not None | |||||
self.assertEqual(new_timings.failure_ts, failure_ts) | self.assertEqual(new_timings.failure_ts, failure_ts) | ||||
self.assertEqual(new_timings.retry_last_ts, failure_ts) | self.assertEqual(new_timings.retry_last_ts, failure_ts) | ||||
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL) | self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL) | ||||
@@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase): | |||||
self.pump() | self.pump() | ||||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) | new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) | ||||
assert new_timings is not None | |||||
self.assertEqual(new_timings.failure_ts, failure_ts) | self.assertEqual(new_timings.failure_ts, failure_ts) | ||||
self.assertEqual(new_timings.retry_last_ts, retry_ts) | self.assertEqual(new_timings.retry_last_ts, retry_ts) | ||||
self.assertGreaterEqual( | self.assertGreaterEqual( | ||||