Python 3.8 provides a native AsyncMock, we can replace the homegrown version we have.tags/v1.92.0rc1
@@ -0,0 +1 @@ | |||
Use `AsyncMock` instead of custom code. |
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import time | |||
from typing import Any, Dict, List, Optional, cast | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
import attr | |||
import canonicaljson | |||
@@ -45,7 +45,6 @@ from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import logcontext_clean, override_config | |||
@@ -291,7 +290,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
with a null `ts_valid_until_ms` | |||
""" | |||
mock_fetcher = Mock() | |||
mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) | |||
mock_fetcher.get_keys = AsyncMock(return_value={}) | |||
key1 = signedjson.key.generate_signing_key("1") | |||
r = self.hs.get_datastores().main.store_server_signature_keys( | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.rest import admin | |||
@@ -20,7 +20,6 @@ from synapse.rest.client import login, room | |||
from synapse.types import JsonDict, UserID, create_requester | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||
@@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] | |||
return_value=("", 1) | |||
) | |||
d = handler._remote_join( | |||
@@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] | |||
return_value=("", 1) | |||
) | |||
d = handler._remote_join( | |||
@@ -143,9 +142,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] | |||
return_value=("", 1) | |||
) | |||
# Artificially raise the complexity | |||
@@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] | |||
return_value=("", 1) | |||
) | |||
d = handler._remote_join( | |||
@@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): | |||
fed_transport = self.hs.get_federation_transport_client() | |||
# Mock out some things, because we don't want to test the whole join | |||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(("", 1)) | |||
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment] | |||
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment] | |||
return_value=("", 1) | |||
) | |||
d = handler._remote_join( | |||
@@ -1,6 +1,6 @@ | |||
from typing import Callable, Collection, List, Optional, Tuple | |||
from unittest import mock | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -19,7 +19,7 @@ from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from synapse.util.retryutils import NotRetryingDestination | |||
from tests.test_utils import event_injection, make_awaitable | |||
from tests.test_utils import event_injection | |||
from tests.unittest import FederatingHomeserverTestCase | |||
@@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): | |||
# This mock is crucial for destination_rooms to be populated. | |||
# TODO: this seems to no longer be the case---tests pass with this mock | |||
# commented out. | |||
state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable({"test", "host2"}) | |||
state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment] | |||
return_value={"test", "host2"} | |||
) | |||
# whenever send_transaction is called, record the pdu data | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Callable, FrozenSet, List, Optional, Set | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from signedjson import key, sign | |||
from signedjson.types import BaseKey, SigningKey | |||
@@ -29,7 +29,6 @@ from synapse.server import HomeServer | |||
from synapse.types import JsonDict, ReadReceipt | |||
from synapse.util import Clock | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import HomeserverTestCase | |||
@@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.federation_transport_client = Mock(spec=["send_transaction"]) | |||
self.federation_transport_client.send_transaction = AsyncMock() | |||
hs = self.setup_test_homeserver( | |||
federation_transport_client=self.federation_transport_client, | |||
) | |||
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable({"test", "host2"}) | |||
hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment] | |||
return_value={"test", "host2"} | |||
) | |||
hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment] | |||
@@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||
def test_send_receipts(self) -> None: | |||
mock_send_transaction = self.federation_transport_client.send_transaction | |||
mock_send_transaction.return_value = make_awaitable({}) | |||
mock_send_transaction.return_value = {} | |||
sender = self.hs.get_federation_sender() | |||
receipt = ReadReceipt( | |||
@@ -104,7 +104,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||
def test_send_receipts_thread(self) -> None: | |||
mock_send_transaction = self.federation_transport_client.send_transaction | |||
mock_send_transaction.return_value = make_awaitable({}) | |||
mock_send_transaction.return_value = {} | |||
# Create receipts for: | |||
# | |||
@@ -180,7 +180,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): | |||
"""Send two receipts in quick succession; the second should be flushed, but | |||
only after 20ms""" | |||
mock_send_transaction = self.federation_transport_client.send_transaction | |||
mock_send_transaction.return_value = make_awaitable({}) | |||
mock_send_transaction.return_value = {} | |||
sender = self.hs.get_federation_sender() | |||
receipt = ReadReceipt( | |||
@@ -276,6 +276,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||
self.federation_transport_client = Mock( | |||
spec=["send_transaction", "query_user_devices"] | |||
) | |||
self.federation_transport_client.send_transaction = AsyncMock() | |||
self.federation_transport_client.query_user_devices = AsyncMock() | |||
return self.setup_test_homeserver( | |||
federation_transport_client=self.federation_transport_client, | |||
) | |||
@@ -317,13 +319,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||
self.record_transaction | |||
) | |||
def record_transaction( | |||
async def record_transaction( | |||
self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None | |||
) -> "defer.Deferred[JsonDict]": | |||
) -> JsonDict: | |||
assert json_cb is not None | |||
data = json_cb() | |||
self.edus.extend(data["edus"]) | |||
return defer.succeed({}) | |||
return {} | |||
def test_send_device_updates(self) -> None: | |||
"""Basic case: each device update should result in an EDU""" | |||
@@ -354,15 +356,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||
# Send the server a device list EDU for the other user, this will cause | |||
# it to try and resync the device lists. | |||
self.federation_transport_client.query_user_devices.return_value = ( | |||
make_awaitable( | |||
{ | |||
"stream_id": "1", | |||
"user_id": "@user2:host2", | |||
"devices": [{"device_id": "D1"}], | |||
} | |||
) | |||
) | |||
self.federation_transport_client.query_user_devices.return_value = { | |||
"stream_id": "1", | |||
"user_id": "@user2:host2", | |||
"devices": [{"device_id": "D1"}], | |||
} | |||
self.get_success( | |||
self.device_handler.device_list_updater.incoming_device_list_update( | |||
@@ -533,7 +531,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||
recovery | |||
""" | |||
mock_send_txn = self.federation_transport_client.send_transaction | |||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | |||
mock_send_txn.side_effect = AssertionError("fail") | |||
# create devices | |||
u1 = self.register_user("user", "pass") | |||
@@ -578,7 +576,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||
This case tests the behaviour when the server has never been reachable. | |||
""" | |||
mock_send_txn = self.federation_transport_client.send_transaction | |||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | |||
mock_send_txn.side_effect = AssertionError("fail") | |||
# create devices | |||
u1 = self.register_user("user", "pass") | |||
@@ -636,7 +634,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): | |||
# now the server goes offline | |||
mock_send_txn = self.federation_transport_client.send_transaction | |||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) | |||
mock_send_txn.side_effect = AssertionError("fail") | |||
self.login("user", "pass", device_id="D2") | |||
self.login("user", "pass", device_id="D3") | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
from typing import Dict, Iterable, List, Optional | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from parameterized import parameterized | |||
@@ -36,7 +36,7 @@ from synapse.util import Clock | |||
from synapse.util.stringutils import random_string | |||
from tests import unittest | |||
from tests.test_utils import event_injection, make_awaitable, simple_async_mock | |||
from tests.test_utils import event_injection, simple_async_mock | |||
from tests.unittest import override_config | |||
from tests.utils import MockClock | |||
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.mock_store = Mock() | |||
self.mock_as_api = Mock() | |||
self.mock_as_api = AsyncMock() | |||
self.mock_scheduler = Mock() | |||
hs = Mock() | |||
hs.get_datastores.return_value = Mock(main=self.mock_store) | |||
self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) | |||
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) | |||
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( | |||
None | |||
) | |||
self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) | |||
self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None) | |||
self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None) | |||
hs.get_application_service_api.return_value = self.mock_as_api | |||
hs.get_application_service_scheduler.return_value = self.mock_scheduler | |||
hs.get_clock.return_value = MockClock() | |||
@@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
self._mkservice(is_interested_in_event=False), | |||
] | |||
self.mock_as_api.query_user.return_value = make_awaitable(True) | |||
self.mock_as_api.query_user.return_value = True | |||
self.mock_store.get_app_services.return_value = services | |||
self.mock_store.get_user_by_id.return_value = make_awaitable([]) | |||
self.mock_store.get_user_by_id = AsyncMock(return_value=[]) | |||
event = Mock( | |||
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" | |||
) | |||
self.mock_store.get_all_new_event_ids_stream.side_effect = [ | |||
make_awaitable((0, {})), | |||
make_awaitable((1, {event.event_id: 0})), | |||
] | |||
self.mock_store.get_events_as_list.side_effect = [ | |||
make_awaitable([]), | |||
make_awaitable([event]), | |||
] | |||
self.mock_store.get_all_new_event_ids_stream = AsyncMock( | |||
side_effect=[ | |||
(0, {}), | |||
(1, {event.event_id: 0}), | |||
] | |||
) | |||
self.mock_store.get_events_as_list = AsyncMock( | |||
side_effect=[ | |||
[], | |||
[event], | |||
] | |||
) | |||
self.handler.notify_interested_services(RoomStreamToken(None, 1)) | |||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | |||
@@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
services = [self._mkservice(is_interested_in_event=True)] | |||
services[0].is_interested_in_user.return_value = True | |||
self.mock_store.get_app_services.return_value = services | |||
self.mock_store.get_user_by_id.return_value = make_awaitable(None) | |||
self.mock_store.get_user_by_id = AsyncMock(return_value=None) | |||
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") | |||
self.mock_as_api.query_user.return_value = make_awaitable(True) | |||
self.mock_store.get_all_new_event_ids_stream.side_effect = [ | |||
make_awaitable((0, {event.event_id: 0})), | |||
] | |||
self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] | |||
self.mock_as_api.query_user.return_value = True | |||
self.mock_store.get_all_new_event_ids_stream = AsyncMock( | |||
side_effect=[ | |||
(0, {event.event_id: 0}), | |||
] | |||
) | |||
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]]) | |||
self.handler.notify_interested_services(RoomStreamToken(None, 0)) | |||
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) | |||
@@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
services = [self._mkservice(is_interested_in_event=True)] | |||
services[0].is_interested_in_user.return_value = True | |||
self.mock_store.get_app_services.return_value = services | |||
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) | |||
self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id}) | |||
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") | |||
self.mock_as_api.query_user.return_value = make_awaitable(True) | |||
self.mock_store.get_all_new_event_ids_stream.side_effect = [ | |||
make_awaitable((0, [event], {event.event_id: 0})), | |||
] | |||
self.mock_as_api.query_user.return_value = True | |||
self.mock_store.get_all_new_event_ids_stream = AsyncMock( | |||
side_effect=[ | |||
(0, [event], {event.event_id: 0}), | |||
] | |||
) | |||
self.handler.notify_interested_services(RoomStreamToken(None, 0)) | |||
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
self._mkservice_alias(is_room_alias_in_namespace=False), | |||
] | |||
self.mock_as_api.query_alias.return_value = make_awaitable(True) | |||
self.mock_as_api.query_alias = AsyncMock(return_value=True) | |||
self.mock_store.get_app_services.return_value = services | |||
self.mock_store.get_association_from_room_alias.return_value = make_awaitable( | |||
Mock(room_id=room_id, servers=servers) | |||
self.mock_store.get_association_from_room_alias = AsyncMock( | |||
return_value=Mock(room_id=room_id, servers=servers) | |||
) | |||
result = self.successResultOf( | |||
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
def test_get_3pe_protocols_protocol_no_response(self) -> None: | |||
service = self._mkservice(False, ["my-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service] | |||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) | |||
self.mock_as_api.get_3pe_protocol.return_value = None | |||
response = self.successResultOf( | |||
defer.ensureDeferred(self.handler.get_3pe_protocols()) | |||
) | |||
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
def test_get_3pe_protocols_select_one_protocol(self) -> None: | |||
service = self._mkservice(False, ["my-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service] | |||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( | |||
{"x-protocol-data": 42, "instances": []} | |||
) | |||
self.mock_as_api.get_3pe_protocol.return_value = { | |||
"x-protocol-data": 42, | |||
"instances": [], | |||
} | |||
response = self.successResultOf( | |||
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) | |||
) | |||
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
def test_get_3pe_protocols_one_protocol(self) -> None: | |||
service = self._mkservice(False, ["my-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service] | |||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( | |||
{"x-protocol-data": 42, "instances": []} | |||
) | |||
self.mock_as_api.get_3pe_protocol.return_value = { | |||
"x-protocol-data": 42, | |||
"instances": [], | |||
} | |||
response = self.successResultOf( | |||
defer.ensureDeferred(self.handler.get_3pe_protocols()) | |||
) | |||
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
service_one = self._mkservice(False, ["my-protocol"]) | |||
service_two = self._mkservice(False, ["other-protocol"]) | |||
self.mock_store.get_app_services.return_value = [service_one, service_two] | |||
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( | |||
{"x-protocol-data": 42, "instances": []} | |||
) | |||
self.mock_as_api.get_3pe_protocol.return_value = { | |||
"x-protocol-data": 42, | |||
"instances": [], | |||
} | |||
response = self.successResultOf( | |||
defer.ensureDeferred(self.handler.get_3pe_protocols()) | |||
) | |||
@@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
interested_service = self._mkservice(is_interested_in_event=True) | |||
services = [interested_service] | |||
self.mock_store.get_app_services.return_value = services | |||
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( | |||
579 | |||
) | |||
self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579) | |||
event = Mock(event_id="event_1") | |||
self.event_source.sources.receipt.get_new_events_as.return_value = ( | |||
make_awaitable(([event], None)) | |||
self.event_source.sources.receipt.get_new_events_as = AsyncMock( | |||
return_value=([event], None) | |||
) | |||
self.handler.notify_interested_services_ephemeral( | |||
@@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
services = [interested_service] | |||
self.mock_store.get_app_services.return_value = services | |||
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( | |||
580 | |||
) | |||
self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580) | |||
event = Mock(event_id="event_1") | |||
self.event_source.sources.receipt.get_new_events_as.return_value = ( | |||
make_awaitable(([event], None)) | |||
self.event_source.sources.receipt.get_new_events_as = AsyncMock( | |||
return_value=([event], None) | |||
) | |||
self.handler.notify_interested_services_ephemeral( | |||
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
A mock representing the ApplicationService. | |||
""" | |||
service = Mock() | |||
service.is_interested_in_event.return_value = make_awaitable( | |||
is_interested_in_event | |||
) | |||
service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event) | |||
service.token = "mock_service_token" | |||
service.url = "mock_service_url" | |||
service.protocols = protocols | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Optional | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock | |||
import pymacaroons | |||
@@ -25,7 +25,6 @@ from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class AuthTestCase(unittest.HomeserverTestCase): | |||
@@ -166,8 +165,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
def test_mau_limits_exceeded_large(self) -> None: | |||
self.auth_blocking._limit_usage_by_mau = True | |||
self.hs.get_datastores().main.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.large_number_of_users) | |||
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( | |||
return_value=self.large_number_of_users | |||
) | |||
self.get_failure( | |||
@@ -177,8 +176,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
ResourceLimitError, | |||
) | |||
self.hs.get_datastores().main.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.large_number_of_users) | |||
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( | |||
return_value=self.large_number_of_users | |||
) | |||
token = self.get_success( | |||
self.auth_handler.create_login_token_for_user_id(self.user1) | |||
@@ -191,8 +190,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
self.auth_blocking._limit_usage_by_mau = True | |||
# Set the server to be at the edge of too many users. | |||
self.hs.get_datastores().main.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.auth_blocking._max_mau_value) | |||
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( | |||
return_value=self.auth_blocking._max_mau_value | |||
) | |||
# If not in monthly active cohort | |||
@@ -208,8 +207,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
self.assertIsNone(self.token_login(token)) | |||
# If in monthly active cohort | |||
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( | |||
return_value=make_awaitable(self.clock.time_msec()) | |||
self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock( | |||
return_value=self.clock.time_msec() | |||
) | |||
self.get_success( | |||
self.auth_handler.create_access_token_for_user_id( | |||
@@ -224,8 +223,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
def test_mau_limits_not_exceeded(self) -> None: | |||
self.auth_blocking._limit_usage_by_mau = True | |||
self.hs.get_datastores().main.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.small_number_of_users) | |||
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( | |||
return_value=self.small_number_of_users | |||
) | |||
# Ensure does not raise exception | |||
self.get_success( | |||
@@ -234,8 +233,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
self.hs.get_datastores().main.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.small_number_of_users) | |||
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( | |||
return_value=self.small_number_of_users | |||
) | |||
token = self.get_success( | |||
self.auth_handler.create_login_token_for_user_id(self.user1) | |||
@@ -32,7 +32,6 @@ from synapse.types import JsonDict, create_requester | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
user1 = "@boris:aaa" | |||
@@ -41,7 +40,7 @@ user2 = "@theresa:bbb" | |||
class DeviceTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.appservice_api = mock.Mock() | |||
self.appservice_api = mock.AsyncMock() | |||
hs = self.setup_test_homeserver( | |||
"server", | |||
application_service_api=self.appservice_api, | |||
@@ -375,13 +374,11 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Setup a response. | |||
self.appservice_api.query_keys.return_value = make_awaitable( | |||
{ | |||
"device_keys": { | |||
local_user: {device_2: device_key_2b, device_3: device_key_3} | |||
} | |||
self.appservice_api.query_keys.return_value = { | |||
"device_keys": { | |||
local_user: {device_2: device_key_2b, device_3: device_key_3} | |||
} | |||
) | |||
} | |||
# Request all devices. | |||
res = self.get_success( | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Awaitable, Callable, Dict | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -27,14 +27,13 @@ from synapse.types import JsonDict, RoomAlias, create_requester | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class DirectoryTestCase(unittest.HomeserverTestCase): | |||
"""Tests the directory service.""" | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.mock_federation = Mock() | |||
self.mock_federation = AsyncMock() | |||
self.mock_registry = Mock() | |||
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} | |||
@@ -73,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) | |||
def test_get_remote_association(self) -> None: | |||
self.mock_federation.make_query.return_value = make_awaitable( | |||
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]} | |||
) | |||
self.mock_federation.make_query.return_value = { | |||
"room_id": "!8765qwer:test", | |||
"servers": ["test", "remote"], | |||
} | |||
result = self.get_success(self.handler.get_association(self.remote_room)) | |||
@@ -13,7 +13,7 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Iterable | |||
from typing import Dict, Iterable | |||
from unittest import mock | |||
from parameterized import parameterized | |||
@@ -31,13 +31,12 @@ from synapse.types import JsonDict, UserID | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.appservice_api = mock.Mock() | |||
self.appservice_api = mock.AsyncMock() | |||
return self.setup_test_homeserver( | |||
federation_client=mock.Mock(), application_service_api=self.appservice_api | |||
) | |||
@@ -801,29 +800,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" | |||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | |||
self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] | |||
return_value=make_awaitable( | |||
{ | |||
"device_keys": {remote_user_id: {}}, | |||
"master_keys": { | |||
remote_user_id: { | |||
"user_id": remote_user_id, | |||
"usage": ["master"], | |||
"keys": {"ed25519:" + remote_master_key: remote_master_key}, | |||
}, | |||
}, | |||
"self_signing_keys": { | |||
remote_user_id: { | |||
"user_id": remote_user_id, | |||
"usage": ["self_signing"], | |||
"keys": { | |||
"ed25519:" | |||
+ remote_self_signing_key: remote_self_signing_key | |||
}, | |||
} | |||
self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[assignment] | |||
return_value={ | |||
"device_keys": {remote_user_id: {}}, | |||
"master_keys": { | |||
remote_user_id: { | |||
"user_id": remote_user_id, | |||
"usage": ["master"], | |||
"keys": {"ed25519:" + remote_master_key: remote_master_key}, | |||
}, | |||
} | |||
) | |||
}, | |||
"self_signing_keys": { | |||
remote_user_id: { | |||
"user_id": remote_user_id, | |||
"usage": ["self_signing"], | |||
"keys": { | |||
"ed25519:" | |||
+ remote_self_signing_key: remote_self_signing_key | |||
}, | |||
} | |||
}, | |||
} | |||
) | |||
e2e_handler = self.hs.get_e2e_keys_handler() | |||
@@ -874,34 +871,29 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
# Pretend we're sharing a room with the user we're querying. If not, | |||
# `_query_devices_for_destination` will return early. | |||
self.store.get_rooms_for_user = mock.Mock( | |||
return_value=make_awaitable({"some_room_id"}) | |||
) | |||
self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"}) | |||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" | |||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" | |||
self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] | |||
return_value=make_awaitable( | |||
{ | |||
self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[assignment] | |||
return_value={ | |||
"user_id": remote_user_id, | |||
"stream_id": 1, | |||
"devices": [], | |||
"master_key": { | |||
"user_id": remote_user_id, | |||
"stream_id": 1, | |||
"devices": [], | |||
"master_key": { | |||
"user_id": remote_user_id, | |||
"usage": ["master"], | |||
"keys": {"ed25519:" + remote_master_key: remote_master_key}, | |||
}, | |||
"self_signing_key": { | |||
"user_id": remote_user_id, | |||
"usage": ["self_signing"], | |||
"keys": { | |||
"ed25519:" | |||
+ remote_self_signing_key: remote_self_signing_key | |||
}, | |||
"usage": ["master"], | |||
"keys": {"ed25519:" + remote_master_key: remote_master_key}, | |||
}, | |||
"self_signing_key": { | |||
"user_id": remote_user_id, | |||
"usage": ["self_signing"], | |||
"keys": { | |||
"ed25519:" + remote_self_signing_key: remote_self_signing_key | |||
}, | |||
} | |||
) | |||
}, | |||
} | |||
) | |||
e2e_handler = self.hs.get_e2e_keys_handler() | |||
@@ -987,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
mock_get_rooms = mock.patch.object( | |||
self.store, | |||
"get_rooms_for_user", | |||
new_callable=mock.MagicMock, | |||
return_value=make_awaitable(["some_room_id"]), | |||
new_callable=mock.AsyncMock, | |||
return_value=["some_room_id"], | |||
) | |||
mock_get_users = mock.patch.object( | |||
self.store, | |||
"get_users_server_still_shares_room_with", | |||
new_callable=mock.MagicMock, | |||
return_value=make_awaitable({remote_user_id}), | |||
new_callable=mock.AsyncMock, | |||
return_value={remote_user_id}, | |||
) | |||
mock_request = mock.patch.object( | |||
self.hs.get_federation_client(), | |||
"query_user_devices", | |||
new_callable=mock.MagicMock, | |||
return_value=make_awaitable(response_body), | |||
new_callable=mock.AsyncMock, | |||
return_value=response_body, | |||
) | |||
with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request: | |||
@@ -1060,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Setup a response, but only for device 2. | |||
self.appservice_api.claim_client_keys.return_value = make_awaitable( | |||
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) | |||
self.appservice_api.claim_client_keys.return_value = ( | |||
{local_user: {device_id_2: otk}}, | |||
[(local_user, device_id_1, "alg1", 1)], | |||
) | |||
# we shouldn't have any unused fallback keys yet | |||
@@ -1127,9 +1120,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Setup a response. | |||
self.appservice_api.claim_client_keys.return_value = make_awaitable( | |||
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) | |||
) | |||
response: Dict[str, Dict[str, Dict[str, JsonDict]]] = { | |||
local_user: {device_id_1: {**as_otk, **as_fallback_key}} | |||
} | |||
self.appservice_api.claim_client_keys.return_value = (response, []) | |||
# Claim OTKs, which will ask the appservice and do nothing else. | |||
claim_res = self.get_success( | |||
@@ -1171,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(fallback_res, ["alg1"]) | |||
# The appservice will return only the OTK. | |||
self.appservice_api.claim_client_keys.return_value = make_awaitable( | |||
({local_user: {device_id_1: as_otk}}, []) | |||
self.appservice_api.claim_client_keys.return_value = ( | |||
{local_user: {device_id_1: as_otk}}, | |||
[], | |||
) | |||
# Claim OTKs, which should return the OTK from the appservice and the | |||
@@ -1234,8 +1229,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
self.assertEqual(fallback_res, ["alg1"]) | |||
# Finally, return only the fallback key from the appservice. | |||
self.appservice_api.claim_client_keys.return_value = make_awaitable( | |||
({local_user: {device_id_1: as_fallback_key}}, []) | |||
self.appservice_api.claim_client_keys.return_value = ( | |||
{local_user: {device_id_1: as_fallback_key}}, | |||
[], | |||
) | |||
# Claim OTKs, which will return only the fallback key from the database. | |||
@@ -1350,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Setup a response. | |||
self.appservice_api.query_keys.return_value = make_awaitable( | |||
{ | |||
"device_keys": { | |||
local_user: {device_2: device_key_2b, device_3: device_key_3} | |||
} | |||
self.appservice_api.query_keys.return_value = { | |||
"device_keys": { | |||
local_user: {device_2: device_key_2b, device_3: device_key_3} | |||
} | |||
) | |||
} | |||
# Request all devices. | |||
res = self.get_success(self.handler.query_local_devices({local_user: None})) | |||
@@ -14,7 +14,7 @@ | |||
import logging | |||
from typing import Collection, Optional, cast | |||
from unittest import TestCase | |||
from unittest.mock import Mock, patch | |||
from unittest.mock import AsyncMock, Mock, patch | |||
from twisted.internet.defer import Deferred | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -40,7 +40,7 @@ from synapse.util import Clock | |||
from synapse.util.stringutils import random_string | |||
from tests import unittest | |||
from tests.test_utils import event_injection, make_awaitable | |||
from tests.test_utils import event_injection | |||
logger = logging.getLogger(__name__) | |||
@@ -370,7 +370,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): | |||
# We mock out the FederationClient.backfill method, to pretend that a remote | |||
# server has returned our fake event. | |||
federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) | |||
federation_client_backfill_mock = AsyncMock(return_value=[event]) | |||
self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] | |||
# We also mock the persist method with a side effect of itself. This allows us | |||
@@ -631,33 +631,29 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): | |||
}, | |||
RoomVersions.V10, | |||
) | |||
mock_make_membership_event = Mock( | |||
return_value=make_awaitable( | |||
( | |||
"example.com", | |||
membership_event, | |||
RoomVersions.V10, | |||
) | |||
mock_make_membership_event = AsyncMock( | |||
return_value=( | |||
"example.com", | |||
membership_event, | |||
RoomVersions.V10, | |||
) | |||
) | |||
mock_send_join = Mock( | |||
return_value=make_awaitable( | |||
SendJoinResult( | |||
membership_event, | |||
"example.com", | |||
state=[ | |||
EVENT_CREATE, | |||
EVENT_CREATOR_MEMBERSHIP, | |||
EVENT_INVITATION_MEMBERSHIP, | |||
], | |||
auth_chain=[ | |||
EVENT_CREATE, | |||
EVENT_CREATOR_MEMBERSHIP, | |||
EVENT_INVITATION_MEMBERSHIP, | |||
], | |||
partial_state=True, | |||
servers_in_room={"example.com"}, | |||
) | |||
mock_send_join = AsyncMock( | |||
return_value=SendJoinResult( | |||
membership_event, | |||
"example.com", | |||
state=[ | |||
EVENT_CREATE, | |||
EVENT_CREATOR_MEMBERSHIP, | |||
EVENT_INVITATION_MEMBERSHIP, | |||
], | |||
auth_chain=[ | |||
EVENT_CREATE, | |||
EVENT_CREATOR_MEMBERSHIP, | |||
EVENT_INVITATION_MEMBERSHIP, | |||
], | |||
partial_state=True, | |||
servers_in_room={"example.com"}, | |||
) | |||
) | |||
@@ -35,7 +35,7 @@ from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import event_injection, make_awaitable | |||
from tests.test_utils import event_injection | |||
class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
@@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
self.mock_federation_transport_client = mock.Mock( | |||
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] | |||
) | |||
self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock() | |||
self.mock_federation_transport_client.get_room_state = mock.AsyncMock() | |||
self.mock_federation_transport_client.get_event = mock.AsyncMock() | |||
self.mock_federation_transport_client.backfill = mock.AsyncMock() | |||
return super().setup_test_homeserver( | |||
federation_transport_client=self.mock_federation_transport_client | |||
) | |||
@@ -198,20 +202,14 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
) | |||
# we expect an outbound request to /state_ids, so stub that out | |||
self.mock_federation_transport_client.get_room_state_ids.return_value = ( | |||
make_awaitable( | |||
{ | |||
"pdu_ids": [e.event_id for e in state_at_prev_event], | |||
"auth_chain_ids": [], | |||
} | |||
) | |||
) | |||
self.mock_federation_transport_client.get_room_state_ids.return_value = { | |||
"pdu_ids": [e.event_id for e in state_at_prev_event], | |||
"auth_chain_ids": [], | |||
} | |||
# we also expect an outbound request to /state | |||
self.mock_federation_transport_client.get_room_state.return_value = ( | |||
make_awaitable( | |||
StateRequestResponse(auth_events=[], state=state_at_prev_event) | |||
) | |||
StateRequestResponse(auth_events=[], state=state_at_prev_event) | |||
) | |||
# we have to bump the clock a bit, to keep the retry logic in | |||
@@ -273,26 +271,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
room_version = self.get_success(main_store.get_room_version(room_id)) | |||
# We expect an outbound request to /state_ids, so stub that out | |||
self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable( | |||
{ | |||
# Mimic the other server not knowing about the state at all. | |||
# We want to cause Synapse to throw an error (`Unable to get | |||
# missing prev_event $fake_prev_event`) and fail to backfill | |||
# the pulled event. | |||
"pdu_ids": [], | |||
"auth_chain_ids": [], | |||
} | |||
) | |||
self.mock_federation_transport_client.get_room_state_ids.return_value = { | |||
# Mimic the other server not knowing about the state at all. | |||
# We want to cause Synapse to throw an error (`Unable to get | |||
# missing prev_event $fake_prev_event`) and fail to backfill | |||
# the pulled event. | |||
"pdu_ids": [], | |||
"auth_chain_ids": [], | |||
} | |||
# We also expect an outbound request to /state | |||
self.mock_federation_transport_client.get_room_state.return_value = make_awaitable( | |||
StateRequestResponse( | |||
# Mimic the other server not knowing about the state at all. | |||
# We want to cause Synapse to throw an error (`Unable to get | |||
# missing prev_event $fake_prev_event`) and fail to backfill | |||
# the pulled event. | |||
auth_events=[], | |||
state=[], | |||
) | |||
self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse( | |||
# Mimic the other server not knowing about the state at all. | |||
# We want to cause Synapse to throw an error (`Unable to get | |||
# missing prev_event $fake_prev_event`) and fail to backfill | |||
# the pulled event. | |||
auth_events=[], | |||
state=[], | |||
) | |||
pulled_event = make_event_from_dict( | |||
@@ -545,25 +540,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
) | |||
# We expect an outbound request to /backfill, so stub that out | |||
self.mock_federation_transport_client.backfill.return_value = make_awaitable( | |||
{ | |||
"origin": self.OTHER_SERVER_NAME, | |||
"origin_server_ts": 123, | |||
"pdus": [ | |||
# This is one of the important aspects of this test: we include | |||
# `pulled_event_without_signatures` so it fails the signature check | |||
# when we filter down the backfill response down to events which | |||
# have valid signatures in | |||
# `_check_sigs_and_hash_for_pulled_events_and_fetch` | |||
pulled_event_without_signatures.get_pdu_json(), | |||
# Then later when we process this valid signature event, when we | |||
# fetch the missing `prev_event`s, we want to make sure that we | |||
# backoff and don't try and fetch `pulled_event_without_signatures` | |||
# again since we know it just had an invalid signature. | |||
pulled_event.get_pdu_json(), | |||
], | |||
} | |||
) | |||
self.mock_federation_transport_client.backfill.return_value = { | |||
"origin": self.OTHER_SERVER_NAME, | |||
"origin_server_ts": 123, | |||
"pdus": [ | |||
# This is one of the important aspects of this test: we include | |||
# `pulled_event_without_signatures` so it fails the signature check | |||
# when we filter down the backfill response down to events which | |||
# have valid signatures in | |||
# `_check_sigs_and_hash_for_pulled_events_and_fetch` | |||
pulled_event_without_signatures.get_pdu_json(), | |||
# Then later when we process this valid signature event, when we | |||
# fetch the missing `prev_event`s, we want to make sure that we | |||
# backoff and don't try and fetch `pulled_event_without_signatures` | |||
# again since we know it just had an invalid signature. | |||
pulled_event.get_pdu_json(), | |||
], | |||
} | |||
# Keep track of the count and make sure we don't make any of these requests | |||
event_endpoint_requested_count = 0 | |||
@@ -731,15 +724,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): | |||
) | |||
# We expect an outbound request to /backfill, so stub that out | |||
self.mock_federation_transport_client.backfill.return_value = make_awaitable( | |||
{ | |||
"origin": self.OTHER_SERVER_NAME, | |||
"origin_server_ts": 123, | |||
"pdus": [ | |||
pulled_event.get_pdu_json(), | |||
], | |||
} | |||
) | |||
self.mock_federation_transport_client.backfill.return_value = { | |||
"origin": self.OTHER_SERVER_NAME, | |||
"origin_server_ts": 123, | |||
"pdus": [ | |||
pulled_event.get_pdu_json(), | |||
], | |||
} | |||
# The function under test: try to backfill and process the pulled event | |||
with LoggingContext("test"): | |||
@@ -16,7 +16,7 @@ | |||
from http import HTTPStatus | |||
from typing import Any, Dict, List, Optional, Type, Union | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -32,7 +32,6 @@ from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import FakeChannel | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
# Login flows we expect to appear in the list after the normal ones. | |||
@@ -187,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) | |||
# check_password must return an awaitable | |||
mock_password_provider.check_password.return_value = make_awaitable(True) | |||
mock_password_provider.check_password = AsyncMock(return_value=True) | |||
channel = self._send_password_login("u", "p") | |||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) | |||
self.assertEqual("@u:test", channel.json_body["user_id"]) | |||
@@ -209,13 +208,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
"""UI Auth should delegate correctly to the password provider""" | |||
# log in twice, to get two devices | |||
mock_password_provider.check_password.return_value = make_awaitable(True) | |||
mock_password_provider.check_password = AsyncMock(return_value=True) | |||
tok1 = self.login("u", "p") | |||
self.login("u", "p", device_id="dev2") | |||
mock_password_provider.reset_mock() | |||
# have the auth provider deny the request to start with | |||
mock_password_provider.check_password.return_value = make_awaitable(False) | |||
mock_password_provider.check_password = AsyncMock(return_value=False) | |||
# make the initial request which returns a 401 | |||
session = self._start_delete_device_session(tok1, "dev2") | |||
@@ -229,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.reset_mock() | |||
# Finally, check the request goes through when we allow it | |||
mock_password_provider.check_password.return_value = make_awaitable(True) | |||
mock_password_provider.check_password = AsyncMock(return_value=True) | |||
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") | |||
self.assertEqual(channel.code, 200) | |||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") | |||
@@ -243,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.register_user("localuser", "localpass") | |||
# check_password must return an awaitable | |||
mock_password_provider.check_password.return_value = make_awaitable(False) | |||
mock_password_provider.check_password = AsyncMock(return_value=False) | |||
channel = self._send_password_login("u", "p") | |||
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) | |||
@@ -260,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.register_user("localuser", "localpass") | |||
# have the auth provider deny the request | |||
mock_password_provider.check_password.return_value = make_awaitable(False) | |||
mock_password_provider.check_password = AsyncMock(return_value=False) | |||
# log in twice, to get two devices | |||
tok1 = self.login("localuser", "localpass") | |||
@@ -303,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.register_user("localuser", "localpass") | |||
# check_password must return an awaitable | |||
mock_password_provider.check_password.return_value = make_awaitable(False) | |||
mock_password_provider.check_password = AsyncMock(return_value=False) | |||
channel = self._send_password_login("localuser", "localpass") | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") | |||
@@ -325,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.register_user("localuser", "localpass") | |||
# allow login via the auth provider | |||
mock_password_provider.check_password.return_value = make_awaitable(True) | |||
mock_password_provider.check_password = AsyncMock(return_value=True) | |||
# log in twice, to get two devices | |||
tok1 = self.login("localuser", "p") | |||
@@ -342,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.check_password.assert_not_called() | |||
# now try deleting with the local password | |||
mock_password_provider.check_password.return_value = make_awaitable(False) | |||
mock_password_provider.check_password = AsyncMock(return_value=False) | |||
channel = self._authed_delete_device( | |||
tok1, "dev2", session, "localuser", "localpass" | |||
) | |||
@@ -396,9 +395,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) | |||
mock_password_provider.check_auth.assert_not_called() | |||
mock_password_provider.check_auth.return_value = make_awaitable( | |||
("@user:test", None) | |||
) | |||
mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) | |||
channel = self._send_login("test.login_type", "u", test_field="y") | |||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) | |||
self.assertEqual("@user:test", channel.json_body["user_id"]) | |||
@@ -447,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.reset_mock() | |||
# right params, but authing as the wrong user | |||
mock_password_provider.check_auth.return_value = make_awaitable( | |||
("@user:test", None) | |||
) | |||
mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) | |||
body["auth"]["test_field"] = "foo" | |||
channel = self._delete_device(tok1, "dev2", body) | |||
self.assertEqual(channel.code, 403) | |||
@@ -460,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
mock_password_provider.reset_mock() | |||
# and finally, succeed | |||
mock_password_provider.check_auth.return_value = make_awaitable( | |||
("@localuser:test", None) | |||
mock_password_provider.check_auth = AsyncMock( | |||
return_value=("@localuser:test", None) | |||
) | |||
channel = self._delete_device(tok1, "dev2", body) | |||
self.assertEqual(channel.code, 200) | |||
@@ -478,10 +473,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
self.custom_auth_provider_callback_test_body() | |||
def custom_auth_provider_callback_test_body(self) -> None: | |||
callback = Mock(return_value=make_awaitable(None)) | |||
callback = AsyncMock(return_value=None) | |||
mock_password_provider.check_auth.return_value = make_awaitable( | |||
("@user:test", callback) | |||
mock_password_provider.check_auth = AsyncMock( | |||
return_value=("@user:test", callback) | |||
) | |||
channel = self._send_login("test.login_type", "u", test_field="y") | |||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) | |||
@@ -616,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
login is disabled""" | |||
# register the user and log in twice via the test login type to get two devices, | |||
self.register_user("localuser", "localpass") | |||
mock_password_provider.check_auth.return_value = make_awaitable( | |||
("@localuser:test", None) | |||
mock_password_provider.check_auth = AsyncMock( | |||
return_value=("@localuser:test", None) | |||
) | |||
channel = self._send_login("test.login_type", "localuser", test_field="") | |||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result) | |||
@@ -835,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
username: The username to use for the test. | |||
registration: Whether to test with registration URLs. | |||
""" | |||
self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(0), | |||
self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[assignment] | |||
return_value=0 | |||
) | |||
m = Mock(return_value=make_awaitable(False)) | |||
m = AsyncMock(return_value=False) | |||
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] | |||
self.register_user(username, "password") | |||
@@ -869,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): | |||
m.assert_called_once_with("email", "foo@test.com", registration) | |||
m = Mock(return_value=make_awaitable(True)) | |||
m = AsyncMock(return_value=True) | |||
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] | |||
channel = self.make_request( | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Awaitable, Callable, Dict | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from parameterized import parameterized | |||
@@ -26,7 +26,6 @@ from synapse.types import JsonDict, UserID | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class ProfileTestCase(unittest.HomeserverTestCase): | |||
@@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
servlets = [admin.register_servlets] | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.mock_federation = Mock() | |||
self.mock_federation = AsyncMock() | |||
self.mock_registry = Mock() | |||
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} | |||
@@ -135,9 +134,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): | |||
) | |||
def test_get_other_name(self) -> None: | |||
self.mock_federation.make_query.return_value = make_awaitable( | |||
{"displayname": "Alice"} | |||
) | |||
self.mock_federation.make_query.return_value = {"displayname": "Alice"} | |||
displayname = self.get_success(self.handler.get_displayname(self.alice)) | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
from typing import Any, Collection, List, Optional, Tuple | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -38,7 +38,6 @@ from synapse.types import ( | |||
) | |||
from synapse.util import Clock | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
from tests.utils import mock_getRawHeaders | |||
@@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_get_or_create_user_mau_not_blocked(self) -> None: | |||
self.store.count_monthly_users = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) | |||
self.store.count_monthly_users = AsyncMock( # type: ignore[assignment] | |||
return_value=self.hs.config.server.max_mau_value - 1 | |||
) | |||
# Ensure does not throw exception | |||
self.get_success(self.get_or_create_user(self.requester, "c", "User")) | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_get_or_create_user_mau_blocked(self) -> None: | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.lots_of_users) | |||
) | |||
self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) | |||
self.get_failure( | |||
self.get_or_create_user(self.requester, "b", "display_name"), | |||
ResourceLimitError, | |||
) | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value) | |||
self.store.get_monthly_active_count = AsyncMock( | |||
return_value=self.hs.config.server.max_mau_value | |||
) | |||
self.get_failure( | |||
self.get_or_create_user(self.requester, "b", "display_name"), | |||
@@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config({"limit_usage_by_mau": True}) | |||
def test_register_mau_blocked(self) -> None: | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.lots_of_users) | |||
) | |||
self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) | |||
self.get_failure( | |||
self.handler.register_user(localpart="local_part"), ResourceLimitError | |||
) | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value) | |||
self.store.get_monthly_active_count = AsyncMock( | |||
return_value=self.hs.config.server.max_mau_value | |||
) | |||
self.get_failure( | |||
self.handler.register_user(localpart="local_part"), ResourceLimitError | |||
@@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
@override_config({"auto_join_rooms": ["#room:test"]}) | |||
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: | |||
room_alias_str = "#room:test" | |||
self.store.is_real_user = Mock(return_value=make_awaitable(False)) | |||
self.store.is_real_user = AsyncMock(return_value=False) | |||
user_id = self.get_success(self.handler.register_user(localpart="support")) | |||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | |||
self.assertEqual(len(rooms), 0) | |||
@@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: | |||
room_alias_str = "#room:test" | |||
self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] | |||
self.store.is_real_user = Mock(return_value=make_awaitable(True)) | |||
self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[assignment] | |||
self.store.is_real_user = AsyncMock(return_value=True) | |||
user_id = self.get_success(self.handler.register_user(localpart="real")) | |||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | |||
directory_handler = self.hs.get_directory_handler() | |||
@@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): | |||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( | |||
self, | |||
) -> None: | |||
self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] | |||
self.store.is_real_user = Mock(return_value=make_awaitable(True)) | |||
self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[assignment] | |||
self.store.is_real_user = AsyncMock(return_value=True) | |||
user_id = self.get_success(self.handler.register_user(localpart="real")) | |||
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) | |||
self.assertEqual(len(rooms), 0) | |||
@@ -1,4 +1,4 @@ | |||
from unittest.mock import Mock, patch | |||
from unittest.mock import AsyncMock, patch | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -16,7 +16,6 @@ from synapse.util import Clock | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
from tests.server import make_request | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import ( | |||
FederatingHomeserverTestCase, | |||
HomeserverTestCase, | |||
@@ -154,25 +153,21 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): | |||
None, | |||
) | |||
mock_make_membership_event = Mock( | |||
return_value=make_awaitable( | |||
( | |||
self.OTHER_SERVER_NAME, | |||
join_event, | |||
self.hs.config.server.default_room_version, | |||
) | |||
mock_make_membership_event = AsyncMock( | |||
return_value=( | |||
self.OTHER_SERVER_NAME, | |||
join_event, | |||
self.hs.config.server.default_room_version, | |||
) | |||
) | |||
mock_send_join = Mock( | |||
return_value=make_awaitable( | |||
SendJoinResult( | |||
join_event, | |||
self.OTHER_SERVER_NAME, | |||
state=[create_event], | |||
auth_chain=[create_event], | |||
partial_state=False, | |||
servers_in_room=frozenset(), | |||
) | |||
mock_send_join = AsyncMock( | |||
return_value=SendJoinResult( | |||
join_event, | |||
self.OTHER_SERVER_NAME, | |||
state=[create_event], | |||
auth_chain=[create_event], | |||
partial_state=False, | |||
servers_in_room=frozenset(), | |||
) | |||
) | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Optional | |||
from unittest.mock import MagicMock, Mock, patch | |||
from unittest.mock import AsyncMock, Mock, patch | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -29,7 +29,6 @@ from synapse.util import Clock | |||
import tests.unittest | |||
import tests.utils | |||
from tests.test_utils import make_awaitable | |||
class SyncTestCase(tests.unittest.HomeserverTestCase): | |||
@@ -253,8 +252,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): | |||
mocked_get_prev_events = patch.object( | |||
self.hs.get_datastores().main, | |||
"get_prev_events_for_room", | |||
new_callable=MagicMock, | |||
return_value=make_awaitable([last_room_creation_event_id]), | |||
new_callable=AsyncMock, | |||
return_value=[last_room_creation_event_id], | |||
) | |||
with mocked_get_prev_events: | |||
self.helper.join(room_id, eve, tok=eve_token) | |||
@@ -15,7 +15,7 @@ | |||
import json | |||
from typing import Dict, List, Set | |||
from unittest.mock import ANY, Mock, call | |||
from unittest.mock import ANY, AsyncMock, Mock, call | |||
from netaddr import IPSet | |||
@@ -33,7 +33,6 @@ from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import ThreadedMemoryReactorClock | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
# Some local users to test with | |||
@@ -74,11 +73,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
# we mock out the keyring so as to skip the authentication check on the | |||
# federation API call. | |||
mock_keyring = Mock(spec=["verify_json_for_server"]) | |||
mock_keyring.verify_json_for_server.return_value = make_awaitable(True) | |||
mock_keyring.verify_json_for_server = AsyncMock(return_value=True) | |||
# we mock out the federation client too | |||
self.mock_federation_client = Mock(spec=["put_json"]) | |||
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) | |||
self.mock_federation_client = AsyncMock(spec=["put_json"]) | |||
self.mock_federation_client.put_json.return_value = (200, "OK") | |||
self.mock_federation_client.agent = MatrixFederationAgent( | |||
reactor, | |||
tls_client_options_factory=None, | |||
@@ -121,20 +120,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
self.datastore = hs.get_datastores().main | |||
self.datastore.get_destination_retry_timings = Mock( | |||
return_value=make_awaitable(None) | |||
) | |||
self.datastore.get_destination_retry_timings = AsyncMock(return_value=None) | |||
self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable((0, [])) | |||
self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[assignment] | |||
return_value=(0, []) | |||
) | |||
self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None) | |||
self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[assignment] | |||
return_value=None | |||
) | |||
self.datastore.get_received_txn_response = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None) | |||
self.datastore.get_received_txn_response = AsyncMock( # type: ignore[assignment] | |||
return_value=None | |||
) | |||
self.room_members: List[UserID] = [] | |||
@@ -173,27 +170,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) | |||
self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] | |||
side_effect=( | |||
# we deliberately return a non-None stream pos to avoid | |||
# doing an initial_sync | |||
lambda: make_awaitable(1) | |||
) | |||
self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[assignment] | |||
# we deliberately return a non-None stream pos to avoid | |||
# doing an initial_sync | |||
return_value=1 | |||
) | |||
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] | |||
self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] | |||
side_effect=lambda: 0 | |||
return_value=0 | |||
) | |||
self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] | |||
side_effect=lambda *args, **kargs: make_awaitable(([], 0)) | |||
self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[assignment] | |||
return_value=([], 0) | |||
) | |||
self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] | |||
side_effect=lambda *args, **kargs: make_awaitable(None) | |||
self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[assignment] | |||
return_value=None | |||
) | |||
self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] | |||
side_effect=lambda *args, **kwargs: make_awaitable(None) | |||
self.datastore.set_received_txn_response = AsyncMock( # type: ignore[assignment] | |||
return_value=None | |||
) | |||
def test_started_typing_local(self) -> None: | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Tuple | |||
from unittest.mock import Mock, patch | |||
from unittest.mock import AsyncMock, Mock, patch | |||
from urllib.parse import quote | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -30,7 +30,7 @@ from synapse.util import Clock | |||
from tests import unittest | |||
from tests.storage.test_user_directory import GetUserDirectoryTables | |||
from tests.test_utils import event_injection, make_awaitable | |||
from tests.test_utils import event_injection | |||
from tests.test_utils.event_injection import inject_member_event | |||
from tests.unittest import override_config | |||
@@ -471,7 +471,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): | |||
self.store.register_user(user_id=r_user_id, password_hash=None) | |||
) | |||
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) | |||
mock_remove_from_user_dir = AsyncMock(return_value=None) | |||
with patch.object( | |||
self.store, "remove_from_user_dir", mock_remove_from_user_dir | |||
): | |||
@@ -14,8 +14,8 @@ | |||
import base64 | |||
import logging | |||
import os | |||
from typing import Any, Awaitable, Callable, Generator, List, Optional, cast | |||
from unittest.mock import Mock, patch | |||
from typing import Generator, List, Optional, cast | |||
from unittest.mock import AsyncMock, patch | |||
import treq | |||
from netaddr import IPSet | |||
@@ -41,7 +41,7 @@ from twisted.web.iweb import IPolicyForHTTPS, IResponse | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.crypto.context_factory import FederationPolicyForHTTPS | |||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent | |||
from synapse.http.federation.srv_resolver import Server | |||
from synapse.http.federation.srv_resolver import Server, SrvResolver | |||
from synapse.http.federation.well_known_resolver import ( | |||
WELL_KNOWN_MAX_SIZE, | |||
WellKnownResolver, | |||
@@ -68,21 +68,11 @@ from tests.utils import checked_cast, default_config | |||
logger = logging.getLogger(__name__) | |||
# Once Async Mocks or lambdas are supported this can go away. | |||
def generate_resolve_service( | |||
result: List[Server], | |||
) -> Callable[[Any], Awaitable[List[Server]]]: | |||
async def resolve_service(_: Any) -> List[Server]: | |||
return result | |||
return resolve_service | |||
class MatrixFederationAgentTests(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.reactor = ThreadedMemoryReactorClock() | |||
self.mock_resolver = Mock() | |||
self.mock_resolver = AsyncMock(spec=SrvResolver) | |||
config_dict = default_config("test", parse=False) | |||
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] | |||
@@ -636,7 +626,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) | |||
self.mock_resolver.resolve_service.return_value = [] | |||
self.reactor.lookups["testserv1"] = "1.2.3.4" | |||
test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar") | |||
@@ -722,7 +712,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) | |||
self.mock_resolver.resolve_service.return_value = [] | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") | |||
@@ -776,7 +766,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
"""Test the behaviour when the .well-known delegates elsewhere""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) | |||
self.mock_resolver.resolve_service.return_value = [] | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
self.reactor.lookups["target-server"] = "1::f" | |||
@@ -840,7 +830,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) | |||
self.mock_resolver.resolve_service.return_value = [] | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
self.reactor.lookups["target-server"] = "1::f" | |||
@@ -930,7 +920,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) | |||
self.mock_resolver.resolve_service.return_value = [] | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") | |||
@@ -986,7 +976,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
# the config left to the default, which will not trust it (since the | |||
# presented cert is signed by a test CA) | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) | |||
self.mock_resolver.resolve_service.return_value = [] | |||
self.reactor.lookups["testserv"] = "1.2.3.4" | |||
config = default_config("test", parse=True) | |||
@@ -1037,9 +1027,9 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( | |||
[Server(host=b"srvtarget", port=8443)] | |||
) | |||
self.mock_resolver.resolve_service.return_value = [ | |||
Server(host=b"srvtarget", port=8443) | |||
] | |||
self.reactor.lookups["srvtarget"] = "1.2.3.4" | |||
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") | |||
@@ -1094,9 +1084,9 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
self.assertEqual(host, "1.2.3.4") | |||
self.assertEqual(port, 443) | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( | |||
[Server(host=b"srvtarget", port=8443)] | |||
) | |||
self.mock_resolver.resolve_service.return_value = [ | |||
Server(host=b"srvtarget", port=8443) | |||
] | |||
self._handle_well_known_connection( | |||
client_factory, | |||
@@ -1137,7 +1127,7 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
"""test the behaviour when the server name has idna chars in""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) | |||
self.mock_resolver.resolve_service.return_value = [] | |||
# the resolver is always called with the IDNA hostname as a native string. | |||
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" | |||
@@ -1201,9 +1191,9 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
"""test the behaviour when the target of a SRV record has idna chars""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( | |||
[Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com | |||
) | |||
self.mock_resolver.resolve_service.return_value = [ | |||
Server(host=b"xn--trget-3qa.com", port=8443) | |||
] # târget.com | |||
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" | |||
test_d = self._make_get_request( | |||
@@ -1407,12 +1397,10 @@ class MatrixFederationAgentTests(unittest.TestCase): | |||
"""Test that other SRV results are tried if the first one fails.""" | |||
self.agent = self._make_agent() | |||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service( | |||
[ | |||
Server(host=b"target.com", port=8443), | |||
Server(host=b"target.com", port=8444), | |||
] | |||
) | |||
self.mock_resolver.resolve_service.return_value = [ | |||
Server(host=b"target.com", port=8443), | |||
Server(host=b"target.com", port=8444), | |||
] | |||
self.reactor.lookups["target.com"] = "1.2.3.4" | |||
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from netaddr import IPSet | |||
@@ -26,7 +26,6 @@ from synapse.types import UserID, create_requester | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
from tests.server import get_clock | |||
from tests.test_utils import make_awaitable | |||
logger = logging.getLogger(__name__) | |||
@@ -62,7 +61,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): | |||
new event. | |||
""" | |||
mock_client = Mock(spec=["put_json"]) | |||
mock_client.put_json.return_value = make_awaitable({}) | |||
mock_client.put_json = AsyncMock(return_value={}) | |||
mock_client.agent = self.matrix_federation_agent | |||
self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
@@ -93,7 +92,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): | |||
new events. | |||
""" | |||
mock_client1 = Mock(spec=["put_json"]) | |||
mock_client1.put_json.return_value = make_awaitable({}) | |||
mock_client1.put_json = AsyncMock(return_value={}) | |||
mock_client1.agent = self.matrix_federation_agent | |||
self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
@@ -108,7 +107,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): | |||
) | |||
mock_client2 = Mock(spec=["put_json"]) | |||
mock_client2.put_json.return_value = make_awaitable({}) | |||
mock_client2.put_json = AsyncMock(return_value={}) | |||
mock_client2.agent = self.matrix_federation_agent | |||
self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
@@ -162,7 +161,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): | |||
new typing EDUs. | |||
""" | |||
mock_client1 = Mock(spec=["put_json"]) | |||
mock_client1.put_json.return_value = make_awaitable({}) | |||
mock_client1.put_json = AsyncMock(return_value={}) | |||
mock_client1.agent = self.matrix_federation_agent | |||
self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
@@ -177,7 +176,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): | |||
) | |||
mock_client2 = Mock(spec=["put_json"]) | |||
mock_client2.put_json.return_value = make_awaitable({}) | |||
mock_client2.put_json = AsyncMock(return_value={}) | |||
mock_client2.agent = self.matrix_federation_agent | |||
self.make_worker_hs( | |||
"synapse.app.generic_worker", | |||
@@ -18,7 +18,7 @@ import os | |||
import urllib.parse | |||
from binascii import unhexlify | |||
from typing import List, Optional | |||
from unittest.mock import Mock, patch | |||
from unittest.mock import AsyncMock, Mock, patch | |||
from parameterized import parameterized, parameterized_class | |||
@@ -45,7 +45,7 @@ from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import FakeSite, make_request | |||
from tests.test_utils import SMALL_PNG, make_awaitable | |||
from tests.test_utils import SMALL_PNG | |||
from tests.unittest import override_config | |||
@@ -419,8 +419,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): | |||
store = self.hs.get_datastores().main | |||
# Set monthly active users to the limit | |||
store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value) | |||
store.get_monthly_active_count = AsyncMock( | |||
return_value=self.hs.config.server.max_mau_value | |||
) | |||
# Check that the blocking of monthly active users is working as expected | |||
# The registration of a new user fails due to the limit | |||
@@ -1834,8 +1834,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): | |||
) | |||
# Set monthly active users to the limit | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value) | |||
self.store.get_monthly_active_count = AsyncMock( | |||
return_value=self.hs.config.server.max_mau_value | |||
) | |||
# Check that the blocking of monthly active users is working as expected | |||
# The registration of a new user fails due to the limit | |||
@@ -1871,8 +1871,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): | |||
handler = self.hs.get_registration_handler() | |||
# Set monthly active users to the limit | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(self.hs.config.server.max_mau_value) | |||
self.store.get_monthly_active_count = AsyncMock( | |||
return_value=self.hs.config.server.max_mau_value | |||
) | |||
# Check that the blocking of monthly active users is working as expected | |||
# The registration of a new user fails due to the limit | |||
@@ -11,13 +11,12 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock | |||
from synapse.rest import admin | |||
from synapse.rest.client import account_data, login, room | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class AccountDataTestCase(unittest.HomeserverTestCase): | |||
@@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): | |||
"""Tests that the on_account_data_updated module callback is called correctly when | |||
a user's account data changes. | |||
""" | |||
mocked_callback = Mock(return_value=make_awaitable(None)) | |||
mocked_callback = AsyncMock(return_value=None) | |||
self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append( | |||
mocked_callback | |||
) | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from http import HTTPStatus | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -23,7 +23,6 @@ from synapse.types import UserID | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class PresenceTestCase(unittest.HomeserverTestCase): | |||
@@ -36,7 +35,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
self.presence_handler = Mock(spec=PresenceHandler) | |||
self.presence_handler.set_state.return_value = make_awaitable(None) | |||
self.presence_handler.set_state = AsyncMock(return_value=None) | |||
hs = self.setup_test_homeserver( | |||
"red", | |||
@@ -15,7 +15,7 @@ | |||
import urllib.parse | |||
from typing import Any, Callable, Dict, List, Optional, Tuple | |||
from unittest.mock import patch | |||
from unittest.mock import AsyncMock, patch | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -28,7 +28,6 @@ from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import FakeChannel | |||
from tests.test_utils import make_awaitable | |||
from tests.test_utils.event_injection import inject_event | |||
from tests.unittest import override_config | |||
@@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase): | |||
# Disable the validation to pretend this came over federation. | |||
with patch( | |||
"synapse.handlers.message.EventCreationHandler._validate_event_relation", | |||
new=lambda self, event: make_awaitable(None), | |||
new_callable=AsyncMock, | |||
return_value=None, | |||
): | |||
# Generate a various relations from a different room. | |||
self.get_success( | |||
@@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): | |||
# not an event the Client-Server API will allow.. | |||
with patch( | |||
"synapse.handlers.message.EventCreationHandler._validate_event_relation", | |||
new=lambda self, event: make_awaitable(None), | |||
new_callable=AsyncMock, | |||
return_value=None, | |||
): | |||
# Create a sub-thread off the thread, which is not allowed. | |||
self._send_relation( | |||
@@ -20,7 +20,7 @@ | |||
import json | |||
from http import HTTPStatus | |||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | |||
from unittest.mock import Mock, call, patch | |||
from unittest.mock import AsyncMock, Mock, call, patch | |||
from urllib import parse as urlparse | |||
from parameterized import param, parameterized | |||
@@ -52,7 +52,6 @@ from synapse.util.stringutils import random_string | |||
from tests import unittest | |||
from tests.http.server._base import make_request_with_cancellation_test | |||
from tests.storage.test_stream import PaginationTestCase | |||
from tests.test_utils import make_awaitable | |||
from tests.test_utils.event_injection import create_event | |||
from tests.unittest import override_config | |||
@@ -70,8 +69,8 @@ class RoomBase(unittest.HomeserverTestCase): | |||
) | |||
self.hs.get_federation_handler = Mock() # type: ignore[assignment] | |||
self.hs.get_federation_handler.return_value.maybe_backfill = Mock( | |||
return_value=make_awaitable(None) | |||
self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock( | |||
return_value=None | |||
) | |||
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: | |||
@@ -2375,7 +2374,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): | |||
] | |||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | |||
return self.setup_test_homeserver(federation_client=Mock()) | |||
return self.setup_test_homeserver(federation_client=AsyncMock()) | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.register_user("user", "pass") | |||
@@ -2385,7 +2384,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): | |||
def test_simple(self) -> None: | |||
"Simple test for searching rooms over federation" | |||
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] | |||
self.federation_client.get_public_rooms.return_value = {} # type: ignore[attr-defined] | |||
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} | |||
@@ -2413,7 +2412,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): | |||
# with a 404, when using search filters. | |||
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] | |||
HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), | |||
make_awaitable({}), | |||
{}, | |||
) | |||
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} | |||
@@ -3413,17 +3412,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): | |||
# Mock a few functions to prevent the test from failing due to failing to talk to | |||
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we | |||
# can check its call_count later on during the test. | |||
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) | |||
make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) | |||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] | |||
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None), | |||
self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment] | |||
return_value=None, | |||
) | |||
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it | |||
# allow everything for now. | |||
# `spec` argument is needed for this function mock to have `__qualname__`, which | |||
# is needed for `Measure` metrics buried in SpamChecker. | |||
mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) | |||
mock = AsyncMock(return_value=True, spec=lambda *x: None) | |||
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( | |||
mock | |||
) | |||
@@ -3451,7 +3450,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): | |||
# Now change the return value of the callback to deny any invite and test that | |||
# we can't send the invite. | |||
mock.return_value = make_awaitable(False) | |||
mock.return_value = False | |||
channel = self.make_request( | |||
method="POST", | |||
path="/rooms/" + self.room_id + "/invite", | |||
@@ -3477,18 +3476,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): | |||
# Mock a few functions to prevent the test from failing due to failing to talk to | |||
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we | |||
# can check its call_count later on during the test. | |||
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) | |||
make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) | |||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] | |||
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None), | |||
self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment] | |||
return_value=None, | |||
) | |||
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it | |||
# allow everything for now. | |||
# `spec` argument is needed for this function mock to have `__qualname__`, which | |||
# is needed for `Measure` metrics buried in SpamChecker. | |||
mock = Mock( | |||
return_value=make_awaitable(synapse.module_api.NOT_SPAM), | |||
mock = AsyncMock( | |||
return_value=synapse.module_api.NOT_SPAM, | |||
spec=lambda *x: None, | |||
) | |||
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( | |||
@@ -3519,7 +3518,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): | |||
# Now change the return value of the callback to deny any invite and test that | |||
# we can't send the invite. We pick an arbitrary error code to be able to check | |||
# that the same code has been returned | |||
mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) | |||
mock.return_value = Codes.CONSENT_NOT_GIVEN | |||
channel = self.make_request( | |||
method="POST", | |||
path="/rooms/" + self.room_id + "/invite", | |||
@@ -3538,7 +3537,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): | |||
make_invite_mock.assert_called_once() | |||
# Run variant with `Tuple[Codes, dict]`. | |||
mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) | |||
mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"}) | |||
channel = self.make_request( | |||
method="POST", | |||
path="/rooms/" + self.room_id + "/invite", | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import threading | |||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -33,7 +33,6 @@ from synapse.util import Clock | |||
from synapse.util.frozenutils import unfreeze | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
if TYPE_CHECKING: | |||
from synapse.module_api import ModuleApi | |||
@@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
def test_on_new_event(self) -> None: | |||
"""Test that the on_new_event callback is called on new events""" | |||
on_new_event = Mock(make_awaitable(None)) | |||
on_new_event = AsyncMock(return_value=None) | |||
self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append( | |||
on_new_event | |||
) | |||
@@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" | |||
# Register a mock callback. | |||
m = Mock(return_value=make_awaitable(None)) | |||
m = AsyncMock(return_value=None) | |||
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( | |||
m | |||
) | |||
@@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" | |||
# Register a mock callback. | |||
m = Mock(return_value=make_awaitable(None)) | |||
m = AsyncMock(return_value=None) | |||
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( | |||
m | |||
) | |||
@@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
correctly when processing a user's deactivation. | |||
""" | |||
# Register a mocked callback. | |||
deactivation_mock = Mock(return_value=make_awaitable(None)) | |||
deactivation_mock = AsyncMock(return_value=None) | |||
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules | |||
third_party_rules._on_user_deactivation_status_changed_callbacks.append( | |||
deactivation_mock, | |||
@@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
# Also register a mocked callback for profile updates, to check that the | |||
# deactivation code calls it in a way that let modules know the user is being | |||
# deactivated. | |||
profile_mock = Mock(return_value=make_awaitable(None)) | |||
profile_mock = AsyncMock(return_value=None) | |||
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( | |||
profile_mock, | |||
) | |||
@@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
well as a reactivation. | |||
""" | |||
# Register a mock callback. | |||
m = Mock(return_value=make_awaitable(None)) | |||
m = AsyncMock(return_value=None) | |||
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules | |||
third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) | |||
@@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
correctly when processing a user's deactivation. | |||
""" | |||
# Register a mocked callback. | |||
deactivation_mock = Mock(return_value=make_awaitable(False)) | |||
deactivation_mock = AsyncMock(return_value=False) | |||
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules | |||
third_party_rules._check_can_deactivate_user_callbacks.append( | |||
deactivation_mock, | |||
@@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
correctly when processing a user's deactivation triggered by a server admin. | |||
""" | |||
# Register a mocked callback. | |||
deactivation_mock = Mock(return_value=make_awaitable(False)) | |||
deactivation_mock = AsyncMock(return_value=False) | |||
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules | |||
third_party_rules._check_can_deactivate_user_callbacks.append( | |||
deactivation_mock, | |||
@@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
correctly when processing an admin's shutdown room request. | |||
""" | |||
# Register a mocked callback. | |||
shutdown_mock = Mock(return_value=make_awaitable(False)) | |||
shutdown_mock = AsyncMock(return_value=False) | |||
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules | |||
third_party_rules._check_can_shutdown_room_callbacks.append( | |||
shutdown_mock, | |||
@@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
associating a 3PID to an account. | |||
""" | |||
# Register a mocked callback. | |||
threepid_bind_mock = Mock(return_value=make_awaitable(None)) | |||
threepid_bind_mock = AsyncMock(return_value=None) | |||
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules | |||
third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) | |||
@@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
just before associating and removing a 3PID to/from an account. | |||
""" | |||
# Pretend to be a Synapse module and register both callbacks as mocks. | |||
on_add_user_third_party_identifier_callback_mock = Mock( | |||
return_value=make_awaitable(None) | |||
) | |||
on_remove_user_third_party_identifier_callback_mock = Mock( | |||
return_value=make_awaitable(None) | |||
on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None) | |||
on_remove_user_third_party_identifier_callback_mock = AsyncMock( | |||
return_value=None | |||
) | |||
self.hs.get_module_api().register_third_party_rules_callbacks( | |||
on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock, | |||
@@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): | |||
when a user is deactivated and their third-party ID associations are deleted. | |||
""" | |||
# Pretend to be a Synapse module and register both callbacks as mocks. | |||
on_remove_user_third_party_identifier_callback_mock = Mock( | |||
return_value=make_awaitable(None) | |||
on_remove_user_third_party_identifier_callback_mock = AsyncMock( | |||
return_value=None | |||
) | |||
self.hs.get_module_api().register_third_party_rules_callbacks( | |||
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock, | |||
@@ -14,7 +14,7 @@ | |||
from http import HTTPStatus | |||
from typing import Any, Generator, Tuple, cast | |||
from unittest.mock import Mock, call | |||
from unittest.mock import AsyncMock, Mock, call | |||
from twisted.internet import defer, reactor as _reactor | |||
@@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
from tests.utils import MockClock | |||
reactor = cast(ISynapseReactor, _reactor) | |||
@@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | |||
def test_executes_given_function( | |||
self, | |||
) -> Generator["defer.Deferred[Any]", object, None]: | |||
cb = Mock(return_value=make_awaitable(self.mock_http_response)) | |||
cb = AsyncMock(return_value=self.mock_http_response) | |||
res = yield self.cache.fetch_or_execute_request( | |||
self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg" | |||
) | |||
@@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | |||
def test_deduplicates_based_on_key( | |||
self, | |||
) -> Generator["defer.Deferred[Any]", object, None]: | |||
cb = Mock(return_value=make_awaitable(self.mock_http_response)) | |||
cb = AsyncMock(return_value=self.mock_http_response) | |||
for i in range(3): # invoke multiple times | |||
res = yield self.cache.fetch_or_execute_request( | |||
self.mock_request, | |||
@@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): | |||
@defer.inlineCallbacks | |||
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: | |||
cb = Mock(return_value=make_awaitable(self.mock_http_response)) | |||
cb = AsyncMock(return_value=self.mock_http_response) | |||
yield self.cache.fetch_or_execute_request( | |||
self.mock_request, self.mock_requester, cb, "an arg" | |||
) | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Tuple | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -29,7 +29,6 @@ from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
from tests.utils import default_config | |||
@@ -69,24 +68,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
assert isinstance(rlsn, ResourceLimitsServerNotices) | |||
self._rlsn = rlsn | |||
self._rlsn._store.user_last_seen_monthly_active = Mock( | |||
return_value=make_awaitable(1000) | |||
) | |||
self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(Mock()) | |||
self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000) | |||
self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[assignment] | |||
return_value=Mock() | |||
) | |||
self._send_notice = self._rlsn._server_notices_manager.send_notice | |||
self.user_id = "@user_id:test" | |||
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( | |||
return_value=make_awaitable("!something:localhost") | |||
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = ( | |||
AsyncMock(return_value="!something:localhost") | |||
) | |||
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( | |||
return_value=make_awaitable("!something:localhost") | |||
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock( | |||
return_value="!something:localhost" | |||
) | |||
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] | |||
self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[assignment] | |||
self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[assignment] | |||
@override_config({"hs_disabled": True}) | |||
def test_maybe_send_server_notice_disabled_hs(self) -> None: | |||
@@ -103,14 +100,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: | |||
"""Test when user has blocked notice, but should have it removed""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None) | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None | |||
) | |||
mock_event = Mock( | |||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | |||
) | |||
self._rlsn._store.get_events = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable({"123": mock_event}) | |||
self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment] | |||
return_value={"123": mock_event} | |||
) | |||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | |||
# Would be better to check the content, but once == remove blocking event | |||
@@ -125,16 +122,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
""" | |||
Test when user has blocked notice, but notice ought to be there (NOOP) | |||
""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None), | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None, | |||
side_effect=ResourceLimitError(403, "foo"), | |||
) | |||
mock_event = Mock( | |||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | |||
) | |||
self._rlsn._store.get_events = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable({"123": mock_event}) | |||
self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment] | |||
return_value={"123": mock_event} | |||
) | |||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | |||
@@ -145,8 +142,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
""" | |||
Test when user does not have blocked notice, but should have one | |||
""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None), | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None, | |||
side_effect=ResourceLimitError(403, "foo"), | |||
) | |||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | |||
@@ -158,8 +155,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
""" | |||
Test when user does not have blocked notice, nor should they (NOOP) | |||
""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None) | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None | |||
) | |||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | |||
@@ -171,12 +168,10 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
Test when user is not part of the MAU cohort - this should not ever | |||
happen - but ... | |||
""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None) | |||
) | |||
self._rlsn._store.user_last_seen_monthly_active = Mock( | |||
return_value=make_awaitable(None) | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None | |||
) | |||
self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None) | |||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | |||
self._send_notice.assert_not_called() | |||
@@ -189,8 +184,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
Test that when server is over MAU limit and alerting is suppressed, then | |||
an alert message is not sent into the room | |||
""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None), | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None, | |||
side_effect=ResourceLimitError( | |||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER | |||
), | |||
@@ -204,8 +199,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
""" | |||
Test that when a server is disabled, that MAU limit alerting is ignored. | |||
""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None), | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None, | |||
side_effect=ResourceLimitError( | |||
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED | |||
), | |||
@@ -223,22 +218,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): | |||
When the room is already in a blocked state, test that when alerting | |||
is suppressed that the room is returned to an unblocked state. | |||
""" | |||
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable(None), | |||
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment] | |||
return_value=None, | |||
side_effect=ResourceLimitError( | |||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER | |||
), | |||
) | |||
self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable((True, [])) | |||
self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[assignment] | |||
return_value=(True, []) | |||
) | |||
mock_event = Mock( | |||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} | |||
) | |||
self._rlsn._store.get_events = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable({"123": mock_event}) | |||
self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment] | |||
return_value={"123": mock_event} | |||
) | |||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | |||
@@ -284,11 +279,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): | |||
self.user_id = "@user_id:test" | |||
def test_server_notice_only_sent_once(self) -> None: | |||
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) | |||
self.store.get_monthly_active_count = AsyncMock(return_value=1000) | |||
self.store.user_last_seen_monthly_active = Mock( | |||
return_value=make_awaitable(1000) | |||
) | |||
self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000) | |||
# Call the function multiple times to ensure we only send the notice once | |||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) | |||
@@ -327,7 +320,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): | |||
hasn't been reached (since it's the only user and the limit is 5), so users | |||
shouldn't receive a server notice. | |||
""" | |||
m = Mock(return_value=make_awaitable(None)) | |||
m = AsyncMock(return_value=None) | |||
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m | |||
user_id = self.register_user("user", "password") | |||
@@ -15,7 +15,7 @@ import json | |||
import os | |||
import tempfile | |||
from typing import List, cast | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
import yaml | |||
@@ -35,7 +35,6 @@ from synapse.types import DeviceListUpdates | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): | |||
@@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): | |||
# we aren't testing store._base stuff here, so mock this out | |||
# (ignore needed because Mypy won't allow us to assign to a method otherwise) | |||
self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] | |||
self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[assignment] | |||
self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) | |||
self.get_success(self._insert_txn(service.id, 10, events)) | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
import yaml | |||
@@ -32,7 +32,7 @@ from synapse.types import JsonDict | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable, simple_async_mock | |||
from tests.test_utils import simple_async_mock | |||
from tests.unittest import override_config | |||
@@ -363,9 +363,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): | |||
# Register the callbacks with more mocks | |||
self.hs.get_module_api().register_background_update_controller_callbacks( | |||
on_update=self._on_update, | |||
min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), | |||
default_batch_size=Mock( | |||
return_value=make_awaitable(self._default_batch_size), | |||
min_batch_size=AsyncMock(return_value=self._default_batch_size), | |||
default_batch_size=AsyncMock( | |||
return_value=self._default_batch_size, | |||
), | |||
) | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
from typing import Any, Dict | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock | |||
from parameterized import parameterized | |||
@@ -30,7 +30,6 @@ from synapse.util import Clock | |||
from tests import unittest | |||
from tests.server import make_request | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import override_config | |||
@@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): | |||
lots_of_users = 100 | |||
user_id = "@user:server" | |||
self.store.get_monthly_active_count = Mock( | |||
return_value=make_awaitable(lots_of_users) | |||
) | |||
self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) | |||
self.get_success( | |||
self.store.insert_client_ip( | |||
user_id, "access_token", "ip", "user_agent", "device_id" | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Dict, List | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -21,7 +21,6 @@ from synapse.server import HomeServer | |||
from synapse.util import Clock | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import default_config, override_config | |||
FORTY_DAYS = 40 * 24 * 60 * 60 | |||
@@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.get_success(d) | |||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] | |||
d = self.store.populate_monthly_active_users(user_id) | |||
self.get_success(d) | |||
@@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
self.store.upsert_monthly_active_user.assert_not_called() | |||
def test_populate_monthly_users_should_update(self) -> None: | |||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] | |||
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] | |||
self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment] | |||
self.store.user_last_seen_monthly_active = Mock( | |||
return_value=make_awaitable(None) | |||
) | |||
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) | |||
d = self.store.populate_monthly_active_users("user_id") | |||
self.get_success(d) | |||
self.store.upsert_monthly_active_user.assert_called_once() | |||
def test_populate_monthly_users_should_not_update(self) -> None: | |||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] | |||
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] | |||
self.store.user_last_seen_monthly_active = Mock( | |||
return_value=make_awaitable(self.hs.get_clock().time_msec()) | |||
self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment] | |||
self.store.user_last_seen_monthly_active = AsyncMock( | |||
return_value=self.hs.get_clock().time_msec() | |||
) | |||
d = self.store.populate_monthly_active_users("user_id") | |||
@@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): | |||
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) | |||
def test_no_users_when_not_tracking(self) -> None: | |||
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] | |||
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment] | |||
self.get_success(self.store.populate_monthly_active_users("@user:sever")) | |||
@@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import ( | |||
PartialStateEventsTracker, | |||
) | |||
from tests.test_utils import make_awaitable | |||
from tests.unittest import TestCase | |||
@@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase): | |||
class PartialCurrentStateTrackerTestCase(TestCase): | |||
def setUp(self) -> None: | |||
self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) | |||
self.mock_store.is_partial_state_room = mock.AsyncMock() | |||
self.tracker = PartialCurrentStateTracker(self.mock_store) | |||
def test_does_not_block_for_full_state_rooms(self) -> None: | |||
self.mock_store.is_partial_state_room.return_value = make_awaitable(False) | |||
self.mock_store.is_partial_state_room.return_value = False | |||
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) | |||
def test_blocks_for_partial_room_state(self) -> None: | |||
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) | |||
self.mock_store.is_partial_state_room.return_value = True | |||
d = ensureDeferred(self.tracker.await_full_state("room_id")) | |||
@@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase): | |||
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) | |||
def test_cancellation(self) -> None: | |||
self.mock_store.is_partial_state_room.return_value = make_awaitable(True) | |||
self.mock_store.is_partial_state_room.return_value = True | |||
d1 = ensureDeferred(self.tracker.await_full_state("room_id")) | |||
self.assertNoResult(d1) | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
from typing import Collection, List, Optional, Union | |||
from unittest.mock import Mock | |||
from unittest.mock import AsyncMock, Mock | |||
from twisted.test.proto_helpers import MemoryReactor | |||
@@ -31,7 +31,6 @@ from synapse.util import Clock | |||
from synapse.util.retryutils import NotRetryingDestination | |||
from tests import unittest | |||
from tests.test_utils import make_awaitable | |||
class MessageAcceptTests(unittest.HomeserverTestCase): | |||
@@ -196,7 +195,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||
# Register a mock on the store so that the incoming update doesn't fail because | |||
# we don't share a room with the user. | |||
store = self.hs.get_datastores().main | |||
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) | |||
store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"]) | |||
# Manually inject a fake device list update. We need this update to include at | |||
# least one prev_id so that the user's device list will need to be retried. | |||
@@ -241,27 +240,24 @@ class MessageAcceptTests(unittest.HomeserverTestCase): | |||
# Register mock device list retrieval on the federation client. | |||
federation_client = self.hs.get_federation_client() | |||
federation_client.query_user_devices = Mock( # type: ignore[assignment] | |||
return_value=make_awaitable( | |||
{ | |||
federation_client.query_user_devices = AsyncMock( # type: ignore[assignment] | |||
return_value={ | |||
"user_id": remote_user_id, | |||
"stream_id": 1, | |||
"devices": [], | |||
"master_key": { | |||
"user_id": remote_user_id, | |||
"stream_id": 1, | |||
"devices": [], | |||
"master_key": { | |||
"user_id": remote_user_id, | |||
"usage": ["master"], | |||
"keys": {"ed25519:" + remote_master_key: remote_master_key}, | |||
}, | |||
"self_signing_key": { | |||
"user_id": remote_user_id, | |||
"usage": ["self_signing"], | |||
"keys": { | |||
"ed25519:" | |||
+ remote_self_signing_key: remote_self_signing_key | |||
}, | |||
"usage": ["master"], | |||
"keys": {"ed25519:" + remote_master_key: remote_master_key}, | |||
}, | |||
"self_signing_key": { | |||
"user_id": remote_user_id, | |||
"usage": ["self_signing"], | |||
"keys": { | |||
"ed25519:" + remote_self_signing_key: remote_self_signing_key | |||
}, | |||
} | |||
) | |||
}, | |||
} | |||
) | |||
# Resync the device list. | |||
@@ -18,7 +18,6 @@ Utilities for running the unit tests | |||
import json | |||
import sys | |||
import warnings | |||
from asyncio import Future | |||
from binascii import unhexlify | |||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar | |||
from unittest.mock import Mock | |||
@@ -57,17 +56,6 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: | |||
raise Exception("awaitable has not yet completed") | |||
def make_awaitable(result: TV) -> Awaitable[TV]: | |||
""" | |||
Makes an awaitable, suitable for mocking an `async` function. | |||
This uses Futures as they can be awaited multiple times so can be returned | |||
to multiple callers. | |||
""" | |||
future: Future[TV] = Future() | |||
future.set_result(result) | |||
return future | |||
def setup_awaitable_errors() -> Callable[[], None]: | |||
""" | |||
Convert warnings from a non-awaited coroutines into errors. | |||