Python 3.8 has a native AsyncMock, use it instead of a custom implementation.tags/v1.92.0rc1
@@ -0,0 +1 @@ | |||||
Use `AsyncMock` instead of custom code. |
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
import pymacaroons | import pymacaroons | ||||
@@ -35,7 +35,6 @@ from synapse.types import Requester, UserID | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests import unittest | from tests import unittest | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.unittest import override_config | from tests.unittest import override_config | ||||
from tests.utils import mock_getRawHeaders | from tests.utils import mock_getRawHeaders | ||||
@@ -60,16 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
# this is overridden for the appservice tests | # this is overridden for the appservice tests | ||||
self.store.get_app_service_by_token = Mock(return_value=None) | self.store.get_app_service_by_token = Mock(return_value=None) | ||||
self.store.insert_client_ip = simple_async_mock(None) | |||||
self.store.is_support_user = simple_async_mock(False) | |||||
self.store.insert_client_ip = AsyncMock(return_value=None) | |||||
self.store.is_support_user = AsyncMock(return_value=False) | |||||
def test_get_user_by_req_user_valid_token(self) -> None: | def test_get_user_by_req_user_valid_token(self) -> None: | ||||
user_info = TokenLookupResult( | user_info = TokenLookupResult( | ||||
user_id=self.test_user, token_id=5, device_id="device" | user_id=self.test_user, token_id=5, device_id="device" | ||||
) | ) | ||||
self.store.get_user_by_access_token = simple_async_mock(user_info) | |||||
self.store.mark_access_token_as_used = simple_async_mock(None) | |||||
self.store.get_user_locked_status = simple_async_mock(False) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=user_info) | |||||
self.store.mark_access_token_as_used = AsyncMock(return_value=None) | |||||
self.store.get_user_locked_status = AsyncMock(return_value=False) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.args[b"access_token"] = [self.test_token] | request.args[b"access_token"] = [self.test_token] | ||||
@@ -78,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
self.assertEqual(requester.user.to_string(), self.test_user) | self.assertEqual(requester.user.to_string(), self.test_user) | ||||
def test_get_user_by_req_user_bad_token(self) -> None: | def test_get_user_by_req_user_bad_token(self) -> None: | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.args[b"access_token"] = [self.test_token] | request.args[b"access_token"] = [self.test_token] | ||||
@@ -91,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
def test_get_user_by_req_user_missing_token(self) -> None: | def test_get_user_by_req_user_missing_token(self) -> None: | ||||
user_info = TokenLookupResult(user_id=self.test_user, token_id=5) | user_info = TokenLookupResult(user_id=self.test_user, token_id=5) | ||||
self.store.get_user_by_access_token = simple_async_mock(user_info) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=user_info) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders() | request.requestHeaders.getRawHeaders = mock_getRawHeaders() | ||||
@@ -106,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None | token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None | ||||
) | ) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "127.0.0.1" | request.getClientAddress.return_value.host = "127.0.0.1" | ||||
@@ -125,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
ip_range_whitelist=IPSet(["192.168/16"]), | ip_range_whitelist=IPSet(["192.168/16"]), | ||||
) | ) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "192.168.10.10" | request.getClientAddress.return_value.host = "192.168.10.10" | ||||
@@ -144,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
ip_range_whitelist=IPSet(["192.168/16"]), | ip_range_whitelist=IPSet(["192.168/16"]), | ||||
) | ) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "131.111.8.42" | request.getClientAddress.return_value.host = "131.111.8.42" | ||||
@@ -158,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
def test_get_user_by_req_appservice_bad_token(self) -> None: | def test_get_user_by_req_appservice_bad_token(self) -> None: | ||||
self.store.get_app_service_by_token = Mock(return_value=None) | self.store.get_app_service_by_token = Mock(return_value=None) | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.args[b"access_token"] = [self.test_token] | request.args[b"access_token"] = [self.test_token] | ||||
@@ -172,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
def test_get_user_by_req_appservice_missing_token(self) -> None: | def test_get_user_by_req_appservice_missing_token(self) -> None: | ||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) | app_service = Mock(token="foobar", url="a_url", sender=self.test_user) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders() | request.requestHeaders.getRawHeaders = mock_getRawHeaders() | ||||
@@ -190,8 +189,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
app_service.is_interested_in_user = Mock(return_value=True) | app_service.is_interested_in_user = Mock(return_value=True) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
# This just needs to return a truth-y value. | # This just needs to return a truth-y value. | ||||
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) | |||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "127.0.0.1" | request.getClientAddress.return_value.host = "127.0.0.1" | ||||
@@ -210,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
app_service.is_interested_in_user = Mock(return_value=False) | app_service.is_interested_in_user = Mock(return_value=False) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "127.0.0.1" | request.getClientAddress.return_value.host = "127.0.0.1" | ||||
@@ -234,10 +233,10 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
app_service.is_interested_in_user = Mock(return_value=True) | app_service.is_interested_in_user = Mock(return_value=True) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
# This just needs to return a truth-y value. | # This just needs to return a truth-y value. | ||||
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) | |||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
# This also needs to just return a truth-y value | # This also needs to just return a truth-y value | ||||
self.store.get_device = simple_async_mock({"hidden": False}) | |||||
self.store.get_device = AsyncMock(return_value={"hidden": False}) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "127.0.0.1" | request.getClientAddress.return_value.host = "127.0.0.1" | ||||
@@ -266,10 +265,10 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
app_service.is_interested_in_user = Mock(return_value=True) | app_service.is_interested_in_user = Mock(return_value=True) | ||||
self.store.get_app_service_by_token = Mock(return_value=app_service) | self.store.get_app_service_by_token = Mock(return_value=app_service) | ||||
# This just needs to return a truth-y value. | # This just needs to return a truth-y value. | ||||
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) | |||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
# This also needs to just return a falsey value | # This also needs to just return a falsey value | ||||
self.store.get_device = simple_async_mock(None) | |||||
self.store.get_device = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "127.0.0.1" | request.getClientAddress.return_value.host = "127.0.0.1" | ||||
@@ -283,8 +282,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) | self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) | ||||
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: | def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: | ||||
self.store.get_user_by_access_token = simple_async_mock( | |||||
TokenLookupResult( | |||||
self.store.get_user_by_access_token = AsyncMock( | |||||
return_value=TokenLookupResult( | |||||
user_id="@baldrick:matrix.org", | user_id="@baldrick:matrix.org", | ||||
device_id="device", | device_id="device", | ||||
token_id=5, | token_id=5, | ||||
@@ -292,9 +291,9 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
token_used=True, | token_used=True, | ||||
) | ) | ||||
) | ) | ||||
self.store.insert_client_ip = simple_async_mock(None) | |||||
self.store.mark_access_token_as_used = simple_async_mock(None) | |||||
self.store.get_user_locked_status = simple_async_mock(False) | |||||
self.store.insert_client_ip = AsyncMock(return_value=None) | |||||
self.store.mark_access_token_as_used = AsyncMock(return_value=None) | |||||
self.store.get_user_locked_status = AsyncMock(return_value=False) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "127.0.0.1" | request.getClientAddress.return_value.host = "127.0.0.1" | ||||
request.args[b"access_token"] = [self.test_token] | request.args[b"access_token"] = [self.test_token] | ||||
@@ -304,8 +303,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: | def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: | ||||
self.auth._track_puppeted_user_ips = True | self.auth._track_puppeted_user_ips = True | ||||
self.store.get_user_by_access_token = simple_async_mock( | |||||
TokenLookupResult( | |||||
self.store.get_user_by_access_token = AsyncMock( | |||||
return_value=TokenLookupResult( | |||||
user_id="@baldrick:matrix.org", | user_id="@baldrick:matrix.org", | ||||
device_id="device", | device_id="device", | ||||
token_id=5, | token_id=5, | ||||
@@ -313,9 +312,9 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
token_used=True, | token_used=True, | ||||
) | ) | ||||
) | ) | ||||
self.store.get_user_locked_status = simple_async_mock(False) | |||||
self.store.insert_client_ip = simple_async_mock(None) | |||||
self.store.mark_access_token_as_used = simple_async_mock(None) | |||||
self.store.get_user_locked_status = AsyncMock(return_value=False) | |||||
self.store.insert_client_ip = AsyncMock(return_value=None) | |||||
self.store.mark_access_token_as_used = AsyncMock(return_value=None) | |||||
request = Mock(args={}) | request = Mock(args={}) | ||||
request.getClientAddress.return_value.host = "127.0.0.1" | request.getClientAddress.return_value.host = "127.0.0.1" | ||||
request.args[b"access_token"] = [self.test_token] | request.args[b"access_token"] = [self.test_token] | ||||
@@ -324,7 +323,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
self.assertEqual(self.store.insert_client_ip.call_count, 2) | self.assertEqual(self.store.insert_client_ip.call_count, 2) | ||||
def test_get_user_from_macaroon(self) -> None: | def test_get_user_from_macaroon(self) -> None: | ||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
user_id = "@baldrick:matrix.org" | user_id = "@baldrick:matrix.org" | ||||
macaroon = pymacaroons.Macaroon( | macaroon = pymacaroons.Macaroon( | ||||
@@ -342,8 +341,8 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
) | ) | ||||
def test_get_guest_user_from_macaroon(self) -> None: | def test_get_guest_user_from_macaroon(self) -> None: | ||||
self.store.get_user_by_id = simple_async_mock({"is_guest": True}) | |||||
self.store.get_user_by_access_token = simple_async_mock(None) | |||||
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True}) | |||||
self.store.get_user_by_access_token = AsyncMock(return_value=None) | |||||
user_id = "@baldrick:matrix.org" | user_id = "@baldrick:matrix.org" | ||||
macaroon = pymacaroons.Macaroon( | macaroon = pymacaroons.Macaroon( | ||||
@@ -373,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
self.auth_blocking._limit_usage_by_mau = True | self.auth_blocking._limit_usage_by_mau = True | ||||
self.store.get_monthly_active_count = simple_async_mock(lots_of_users) | |||||
self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) | |||||
e = self.get_failure( | e = self.get_failure( | ||||
self.auth_blocking.check_auth_blocking(), ResourceLimitError | self.auth_blocking.check_auth_blocking(), ResourceLimitError | ||||
@@ -383,25 +382,27 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
self.assertEqual(e.value.code, 403) | self.assertEqual(e.value.code, 403) | ||||
# Ensure does not throw an error | # Ensure does not throw an error | ||||
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) | |||||
self.store.get_monthly_active_count = AsyncMock( | |||||
return_value=small_number_of_users | |||||
) | |||||
self.get_success(self.auth_blocking.check_auth_blocking()) | self.get_success(self.auth_blocking.check_auth_blocking()) | ||||
def test_blocking_mau__depending_on_user_type(self) -> None: | def test_blocking_mau__depending_on_user_type(self) -> None: | ||||
self.auth_blocking._max_mau_value = 50 | self.auth_blocking._max_mau_value = 50 | ||||
self.auth_blocking._limit_usage_by_mau = True | self.auth_blocking._limit_usage_by_mau = True | ||||
self.store.get_monthly_active_count = simple_async_mock(100) | |||||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | |||||
# Support users allowed | # Support users allowed | ||||
self.get_success( | self.get_success( | ||||
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) | self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) | ||||
) | ) | ||||
self.store.get_monthly_active_count = simple_async_mock(100) | |||||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | |||||
# Bots not allowed | # Bots not allowed | ||||
self.get_failure( | self.get_failure( | ||||
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), | self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), | ||||
ResourceLimitError, | ResourceLimitError, | ||||
) | ) | ||||
self.store.get_monthly_active_count = simple_async_mock(100) | |||||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | |||||
# Real users not allowed | # Real users not allowed | ||||
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) | self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) | ||||
@@ -412,9 +413,9 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
self.auth_blocking._limit_usage_by_mau = True | self.auth_blocking._limit_usage_by_mau = True | ||||
self.auth_blocking._track_appservice_user_ips = False | self.auth_blocking._track_appservice_user_ips = False | ||||
self.store.get_monthly_active_count = simple_async_mock(100) | |||||
self.store.user_last_seen_monthly_active = simple_async_mock() | |||||
self.store.is_trial_user = simple_async_mock() | |||||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | |||||
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) | |||||
self.store.is_trial_user = AsyncMock(return_value=False) | |||||
appservice = ApplicationService( | appservice = ApplicationService( | ||||
"abcd", | "abcd", | ||||
@@ -443,9 +444,9 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
self.auth_blocking._limit_usage_by_mau = True | self.auth_blocking._limit_usage_by_mau = True | ||||
self.auth_blocking._track_appservice_user_ips = True | self.auth_blocking._track_appservice_user_ips = True | ||||
self.store.get_monthly_active_count = simple_async_mock(100) | |||||
self.store.user_last_seen_monthly_active = simple_async_mock() | |||||
self.store.is_trial_user = simple_async_mock() | |||||
self.store.get_monthly_active_count = AsyncMock(return_value=100) | |||||
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) | |||||
self.store.is_trial_user = AsyncMock(return_value=False) | |||||
appservice = ApplicationService( | appservice = ApplicationService( | ||||
"abcd", | "abcd", | ||||
@@ -473,7 +474,7 @@ class AuthTestCase(unittest.HomeserverTestCase): | |||||
def test_reserved_threepid(self) -> None: | def test_reserved_threepid(self) -> None: | ||||
self.auth_blocking._limit_usage_by_mau = True | self.auth_blocking._limit_usage_by_mau = True | ||||
self.auth_blocking._max_mau_value = 1 | self.auth_blocking._max_mau_value = 1 | ||||
self.store.get_monthly_active_count = simple_async_mock(2) | |||||
self.store.get_monthly_active_count = AsyncMock(return_value=2) | |||||
threepid = {"medium": "email", "address": "reserved@server.com"} | threepid = {"medium": "email", "address": "reserved@server.com"} | ||||
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} | unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} | ||||
self.auth_blocking._mau_limits_reserved_threepids = [threepid] | self.auth_blocking._mau_limits_reserved_threepids = [threepid] | ||||
@@ -13,14 +13,13 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
import re | import re | ||||
from typing import Any, Generator | from typing import Any, Generator | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
from twisted.internet import defer | from twisted.internet import defer | ||||
from synapse.appservice import ApplicationService, Namespace | from synapse.appservice import ApplicationService, Namespace | ||||
from tests import unittest | from tests import unittest | ||||
from tests.test_utils import simple_async_mock | |||||
def _regex(regex: str, exclusive: bool = True) -> Namespace: | def _regex(regex: str, exclusive: bool = True) -> Namespace: | ||||
@@ -43,8 +42,8 @@ class ApplicationServiceTestCase(unittest.TestCase): | |||||
) | ) | ||||
self.store = Mock() | self.store = Mock() | ||||
self.store.get_aliases_for_room = simple_async_mock([]) | |||||
self.store.get_local_users_in_room = simple_async_mock([]) | |||||
self.store.get_aliases_for_room = AsyncMock(return_value=[]) | |||||
self.store.get_local_users_in_room = AsyncMock(return_value=[]) | |||||
@defer.inlineCallbacks | @defer.inlineCallbacks | ||||
def test_regex_user_id_prefix_match( | def test_regex_user_id_prefix_match( | ||||
@@ -127,10 +126,10 @@ class ApplicationServiceTestCase(unittest.TestCase): | |||||
self.service.namespaces[ApplicationService.NS_ALIASES].append( | self.service.namespaces[ApplicationService.NS_ALIASES].append( | ||||
_regex("#irc_.*:matrix.org") | _regex("#irc_.*:matrix.org") | ||||
) | ) | ||||
self.store.get_aliases_for_room = simple_async_mock( | |||||
["#irc_foobar:matrix.org", "#athing:matrix.org"] | |||||
self.store.get_aliases_for_room = AsyncMock( | |||||
return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"] | |||||
) | ) | ||||
self.store.get_local_users_in_room = simple_async_mock([]) | |||||
self.store.get_local_users_in_room = AsyncMock(return_value=[]) | |||||
self.assertTrue( | self.assertTrue( | ||||
( | ( | ||||
yield self.service.is_interested_in_event( | yield self.service.is_interested_in_event( | ||||
@@ -182,10 +181,10 @@ class ApplicationServiceTestCase(unittest.TestCase): | |||||
self.service.namespaces[ApplicationService.NS_ALIASES].append( | self.service.namespaces[ApplicationService.NS_ALIASES].append( | ||||
_regex("#irc_.*:matrix.org") | _regex("#irc_.*:matrix.org") | ||||
) | ) | ||||
self.store.get_aliases_for_room = simple_async_mock( | |||||
["#xmpp_foobar:matrix.org", "#athing:matrix.org"] | |||||
self.store.get_aliases_for_room = AsyncMock( | |||||
return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] | |||||
) | ) | ||||
self.store.get_local_users_in_room = simple_async_mock([]) | |||||
self.store.get_local_users_in_room = AsyncMock(return_value=[]) | |||||
self.assertFalse( | self.assertFalse( | ||||
( | ( | ||||
yield defer.ensureDeferred( | yield defer.ensureDeferred( | ||||
@@ -205,8 +204,10 @@ class ApplicationServiceTestCase(unittest.TestCase): | |||||
) | ) | ||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) | self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) | ||||
self.event.sender = "@irc_foobar:matrix.org" | self.event.sender = "@irc_foobar:matrix.org" | ||||
self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) | |||||
self.store.get_local_users_in_room = simple_async_mock([]) | |||||
self.store.get_aliases_for_room = AsyncMock( | |||||
return_value=["#irc_barfoo:matrix.org"] | |||||
) | |||||
self.store.get_local_users_in_room = AsyncMock(return_value=[]) | |||||
self.assertTrue( | self.assertTrue( | ||||
( | ( | ||||
yield self.service.is_interested_in_event( | yield self.service.is_interested_in_event( | ||||
@@ -235,10 +236,10 @@ class ApplicationServiceTestCase(unittest.TestCase): | |||||
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: | def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: | ||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) | self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) | ||||
# Note that @irc_fo:here is the AS user. | # Note that @irc_fo:here is the AS user. | ||||
self.store.get_local_users_in_room = simple_async_mock( | |||||
["@alice:here", "@irc_fo:here", "@bob:here"] | |||||
self.store.get_local_users_in_room = AsyncMock( | |||||
return_value=["@alice:here", "@irc_fo:here", "@bob:here"] | |||||
) | ) | ||||
self.store.get_aliases_for_room = simple_async_mock([]) | |||||
self.store.get_aliases_for_room = AsyncMock(return_value=[]) | |||||
self.event.sender = "@xmpp_foobar:matrix.org" | self.event.sender = "@xmpp_foobar:matrix.org" | ||||
self.assertTrue( | self.assertTrue( | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import List, Optional, Sequence, Tuple, cast | from typing import List, Optional, Sequence, Tuple, cast | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
from typing_extensions import TypeAlias | from typing_extensions import TypeAlias | ||||
@@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests import unittest | from tests import unittest | ||||
from tests.test_utils import simple_async_mock | |||||
from ..utils import MockClock | from ..utils import MockClock | ||||
@@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): | |||||
txn = Mock(id=txn_id, service=service, events=events) | txn = Mock(id=txn_id, service=service, events=events) | ||||
# mock methods | # mock methods | ||||
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) | |||||
txn.send = simple_async_mock(True) | |||||
txn.complete = simple_async_mock(True) | |||||
self.store.create_appservice_txn = simple_async_mock(txn) | |||||
self.store.get_appservice_state = AsyncMock( | |||||
return_value=ApplicationServiceState.UP | |||||
) | |||||
txn.send = AsyncMock(return_value=True) | |||||
txn.complete = AsyncMock(return_value=True) | |||||
self.store.create_appservice_txn = AsyncMock(return_value=txn) | |||||
# actual call | # actual call | ||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) | self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) | ||||
@@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): | |||||
events = [Mock(), Mock()] | events = [Mock(), Mock()] | ||||
txn = Mock(id="idhere", service=service, events=events) | txn = Mock(id="idhere", service=service, events=events) | ||||
self.store.get_appservice_state = simple_async_mock( | |||||
ApplicationServiceState.DOWN | |||||
self.store.get_appservice_state = AsyncMock( | |||||
return_value=ApplicationServiceState.DOWN | |||||
) | ) | ||||
self.store.create_appservice_txn = simple_async_mock(txn) | |||||
self.store.create_appservice_txn = AsyncMock(return_value=txn) | |||||
# actual call | # actual call | ||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) | self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) | ||||
@@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): | |||||
txn = Mock(id=txn_id, service=service, events=events) | txn = Mock(id=txn_id, service=service, events=events) | ||||
# mock methods | # mock methods | ||||
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) | |||||
self.store.set_appservice_state = simple_async_mock(True) | |||||
txn.send = simple_async_mock(False) # fails to send | |||||
self.store.create_appservice_txn = simple_async_mock(txn) | |||||
self.store.get_appservice_state = AsyncMock( | |||||
return_value=ApplicationServiceState.UP | |||||
) | |||||
self.store.set_appservice_state = AsyncMock(return_value=True) | |||||
txn.send = AsyncMock(return_value=False) # fails to send | |||||
self.store.create_appservice_txn = AsyncMock(return_value=txn) | |||||
# actual call | # actual call | ||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) | self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) | ||||
@@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): | |||||
self.as_api = Mock() | self.as_api = Mock() | ||||
self.store = Mock() | self.store = Mock() | ||||
self.service = Mock() | self.service = Mock() | ||||
self.callback = simple_async_mock() | |||||
self.callback = AsyncMock() | |||||
self.recoverer = _Recoverer( | self.recoverer = _Recoverer( | ||||
clock=cast(Clock, self.clock), | clock=cast(Clock, self.clock), | ||||
as_api=self.as_api, | as_api=self.as_api, | ||||
@@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): | |||||
self.recoverer.recover() | self.recoverer.recover() | ||||
# shouldn't have called anything prior to waiting for exp backoff | # shouldn't have called anything prior to waiting for exp backoff | ||||
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) | self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) | ||||
txn.send = simple_async_mock(True) | |||||
txn.complete = simple_async_mock(None) | |||||
txn.send = AsyncMock(return_value=True) | |||||
txn.complete = AsyncMock(return_value=None) | |||||
# wait for exp backoff | # wait for exp backoff | ||||
self.clock.advance_time(2) | self.clock.advance_time(2) | ||||
self.assertEqual(1, txn.send.call_count) | self.assertEqual(1, txn.send.call_count) | ||||
@@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): | |||||
self.recoverer.recover() | self.recoverer.recover() | ||||
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) | self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) | ||||
txn.send = simple_async_mock(False) | |||||
txn.complete = simple_async_mock(None) | |||||
txn.send = AsyncMock(return_value=False) | |||||
txn.complete = AsyncMock(return_value=None) | |||||
self.clock.advance_time(2) | self.clock.advance_time(2) | ||||
self.assertEqual(1, txn.send.call_count) | self.assertEqual(1, txn.send.call_count) | ||||
self.assertEqual(0, txn.complete.call_count) | self.assertEqual(0, txn.complete.call_count) | ||||
@@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): | |||||
self.assertEqual(3, txn.send.call_count) | self.assertEqual(3, txn.send.call_count) | ||||
self.assertEqual(0, txn.complete.call_count) | self.assertEqual(0, txn.complete.call_count) | ||||
self.assertEqual(0, self.callback.call_count) | self.assertEqual(0, self.callback.call_count) | ||||
txn.send = simple_async_mock(True) # successfully send the txn | |||||
txn.send = AsyncMock(return_value=True) # successfully send the txn | |||||
pop_txn = True # returns the txn the first time, then no more. | pop_txn = True # returns the txn the first time, then no more. | ||||
self.clock.advance_time(16) | self.clock.advance_time(16) | ||||
self.assertEqual(1, txn.send.call_count) # new mock reset call count | self.assertEqual(1, txn.send.call_count) # new mock reset call count | ||||
@@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): | |||||
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: | def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: | ||||
self.scheduler = ApplicationServiceScheduler(hs) | self.scheduler = ApplicationServiceScheduler(hs) | ||||
self.txn_ctrl = Mock() | self.txn_ctrl = Mock() | ||||
self.txn_ctrl.send = simple_async_mock() | |||||
self.txn_ctrl.send = AsyncMock() | |||||
# Replace instantiated _TransactionController instances with our Mock | # Replace instantiated _TransactionController instances with our Mock | ||||
self.scheduler.txn_ctrl = self.txn_ctrl | self.scheduler.txn_ctrl = self.txn_ctrl | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union | from typing import Dict, Iterable, List, Optional, Set, Tuple, Union | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
import attr | import attr | ||||
@@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests.handlers.test_sync import generate_sync_config | from tests.handlers.test_sync import generate_sync_config | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.unittest import ( | from tests.unittest import ( | ||||
FederatingHomeserverTestCase, | FederatingHomeserverTestCase, | ||||
HomeserverTestCase, | HomeserverTestCase, | ||||
@@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
# Mock out the calls over federation. | # Mock out the calls over federation. | ||||
self.fed_transport_client = Mock(spec=["send_transaction"]) | self.fed_transport_client = Mock(spec=["send_transaction"]) | ||||
self.fed_transport_client.send_transaction = simple_async_mock({}) | |||||
self.fed_transport_client.send_transaction = AsyncMock(return_value={}) | |||||
hs = self.setup_test_homeserver( | hs = self.setup_test_homeserver( | ||||
federation_transport_client=self.fed_transport_client, | federation_transport_client=self.fed_transport_client, | ||||
@@ -36,7 +36,7 @@ from synapse.util import Clock | |||||
from synapse.util.stringutils import random_string | from synapse.util.stringutils import random_string | ||||
from tests import unittest | from tests import unittest | ||||
from tests.test_utils import event_injection, simple_async_mock | |||||
from tests.test_utils import event_injection | |||||
from tests.unittest import override_config | from tests.unittest import override_config | ||||
from tests.utils import MockClock | from tests.utils import MockClock | ||||
@@ -399,7 +399,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||||
self.hs = hs | self.hs = hs | ||||
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that | # Mock the ApplicationServiceScheduler's _TransactionController's send method so that | ||||
# we can track any outgoing ephemeral events | # we can track any outgoing ephemeral events | ||||
self.send_mock = simple_async_mock() | |||||
self.send_mock = AsyncMock() | |||||
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] | hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] | ||||
# Mock out application services, and allow defining our own in tests | # Mock out application services, and allow defining our own in tests | ||||
@@ -897,7 +897,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) | |||||
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that | # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that | ||||
# will be sent over the wire | # will be sent over the wire | ||||
self.put_json = simple_async_mock() | |||||
self.put_json = AsyncMock() | |||||
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] | hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] | ||||
# Mock out application services, and allow defining our own in tests | # Mock out application services, and allow defining our own in tests | ||||
@@ -1003,7 +1003,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): | |||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | ||||
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that | # Mock the ApplicationServiceScheduler's _TransactionController's send method so that | ||||
# we can track what's going out | # we can track what's going out | ||||
self.send_mock = simple_async_mock() | |||||
self.send_mock = AsyncMock() | |||||
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. | hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. | ||||
# Define an application service for the tests | # Define an application service for the tests | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
@@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse | |||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.unittest import HomeserverTestCase, override_config | from tests.unittest import HomeserverTestCase, override_config | ||||
# These are a few constants that are used as config parameters in the tests. | # These are a few constants that are used as config parameters in the tests. | ||||
@@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
cas_response = CasResponse("test_user", {}) | cas_response = CasResponse("test_user", {}) | ||||
request = _mock_request() | request = _mock_request() | ||||
@@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
# Map a user via SSO. | # Map a user via SSO. | ||||
cas_response = CasResponse("test_user", {}) | cas_response = CasResponse("test_user", {}) | ||||
@@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
cas_response = CasResponse("föö", {}) | cas_response = CasResponse("föö", {}) | ||||
request = _mock_request() | request = _mock_request() | ||||
@@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
# The response doesn't have the proper userGroup or department. | # The response doesn't have the proper userGroup or department. | ||||
cas_response = CasResponse("test_user", {}) | cas_response = CasResponse("test_user", {}) | ||||
@@ -39,7 +39,7 @@ from synapse.server import HomeServer | |||||
from synapse.types import JsonDict | from synapse.types import JsonDict | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock | |||||
from tests.test_utils import FakeResponse, get_awaitable_result | |||||
from tests.unittest import HomeserverTestCase, skip_unless | from tests.unittest import HomeserverTestCase, skip_unless | ||||
from tests.utils import mock_getRawHeaders | from tests.utils import mock_getRawHeaders | ||||
@@ -147,7 +147,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_inactive_token(self) -> None: | def test_inactive_token(self) -> None: | ||||
"""The handler should return a 403 where the token is inactive.""" | """The handler should return a 403 where the token is inactive.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={"active": False}, | payload={"active": False}, | ||||
@@ -166,7 +166,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_no_scope(self) -> None: | def test_active_no_scope(self) -> None: | ||||
"""The handler should return a 403 where no scope is given.""" | """The handler should return a 403 where no scope is given.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={"active": True}, | payload={"active": True}, | ||||
@@ -185,7 +185,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_user_no_subject(self) -> None: | def test_active_user_no_subject(self) -> None: | ||||
"""The handler should return a 500 when no subject is present.""" | """The handler should return a 500 when no subject is present.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}, | payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}, | ||||
@@ -204,7 +204,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_no_user_scope(self) -> None: | def test_active_no_user_scope(self) -> None: | ||||
"""The handler should return a 500 when no subject is present.""" | """The handler should return a 500 when no subject is present.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -227,7 +227,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_admin_not_user(self) -> None: | def test_active_admin_not_user(self) -> None: | ||||
"""The handler should raise when the scope has admin right but not user.""" | """The handler should raise when the scope has admin right but not user.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -251,7 +251,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_admin(self) -> None: | def test_active_admin(self) -> None: | ||||
"""The handler should return a requester with admin rights.""" | """The handler should return a requester with admin rights.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -281,7 +281,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_admin_highest_privilege(self) -> None: | def test_active_admin_highest_privilege(self) -> None: | ||||
"""The handler should resolve to the most permissive scope.""" | """The handler should resolve to the most permissive scope.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -313,7 +313,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_user(self) -> None: | def test_active_user(self) -> None: | ||||
"""The handler should return a requester with normal user rights.""" | """The handler should return a requester with normal user rights.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
"""The handler should return a requester with normal user rights | """The handler should return a requester with normal user rights | ||||
and an user ID matching the one specified in query param `user_id`""" | and an user ID matching the one specified in query param `user_id`""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -378,7 +378,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_user_with_device(self) -> None: | def test_active_user_with_device(self) -> None: | ||||
"""The handler should return a requester with normal user rights and a device ID.""" | """The handler should return a requester with normal user rights and a device ID.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -408,7 +408,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_multiple_devices(self) -> None: | def test_multiple_devices(self) -> None: | ||||
"""The handler should raise an error if multiple devices are found in the scope.""" | """The handler should raise an error if multiple devices are found in the scope.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -433,7 +433,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_guest_not_allowed(self) -> None: | def test_active_guest_not_allowed(self) -> None: | ||||
"""The handler should return an insufficient scope error.""" | """The handler should return an insufficient scope error.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -463,7 +463,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_active_guest_allowed(self) -> None: | def test_active_guest_allowed(self) -> None: | ||||
"""The handler should return a requester with guest user rights and a device ID.""" | """The handler should return a requester with guest user rights and a device ID.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -499,19 +499,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders() | request.requestHeaders.getRawHeaders = mock_getRawHeaders() | ||||
# The introspection endpoint is returning an error. | # The introspection endpoint is returning an error. | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse(code=500, body=b"Internal Server Error") | return_value=FakeResponse(code=500, body=b"Internal Server Error") | ||||
) | ) | ||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) | error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) | ||||
self.assertEqual(error.value.code, 503) | self.assertEqual(error.value.code, 503) | ||||
# The introspection endpoint request fails. | # The introspection endpoint request fails. | ||||
self.http_client.request = simple_async_mock(raises=Exception()) | |||||
self.http_client.request = AsyncMock(side_effect=Exception()) | |||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) | error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) | ||||
self.assertEqual(error.value.code, 503) | self.assertEqual(error.value.code, 503) | ||||
# The introspection endpoint does not return a JSON object. | # The introspection endpoint does not return a JSON object. | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, payload=["this is an array", "not an object"] | code=200, payload=["this is an array", "not an object"] | ||||
) | ) | ||||
@@ -520,7 +520,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
self.assertEqual(error.value.code, 503) | self.assertEqual(error.value.code, 503) | ||||
# The introspection endpoint does not return valid JSON. | # The introspection endpoint does not return valid JSON. | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse(code=200, body=b"this is not valid JSON") | return_value=FakeResponse(code=200, body=b"this is not valid JSON") | ||||
) | ) | ||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) | error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) | ||||
@@ -528,7 +528,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_introspection_token_cache(self) -> None: | def test_introspection_token_cache(self) -> None: | ||||
access_token = "open_sesame" | access_token = "open_sesame" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={"active": "true", "scope": "guest", "jti": access_token}, | payload={"active": "true", "scope": "guest", "jti": access_token}, | ||||
@@ -559,7 +559,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a | # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a | ||||
# token with a soon-to-expire `exp` field to the cache | # token with a soon-to-expire `exp` field to the cache | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -640,7 +640,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase): | |||||
def test_cross_signing(self) -> None: | def test_cross_signing(self) -> None: | ||||
"""Try uploading device keys with OAuth delegation enabled.""" | """Try uploading device keys with OAuth delegation enabled.""" | ||||
self.http_client.request = simple_async_mock( | |||||
self.http_client.request = AsyncMock( | |||||
return_value=FakeResponse.json( | return_value=FakeResponse.json( | ||||
code=200, | code=200, | ||||
payload={ | payload={ | ||||
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
import os | import os | ||||
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple | from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple | ||||
from unittest.mock import ANY, Mock, patch | |||||
from unittest.mock import ANY, AsyncMock, Mock, patch | |||||
from urllib.parse import parse_qs, urlparse | from urllib.parse import parse_qs, urlparse | ||||
import pymacaroons | import pymacaroons | ||||
@@ -28,7 +28,7 @@ from synapse.util import Clock | |||||
from synapse.util.macaroons import get_value_from_macaroon | from synapse.util.macaroons import get_value_from_macaroon | ||||
from synapse.util.stringutils import random_string | from synapse.util.stringutils import random_string | ||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock | |||||
from tests.test_utils import FakeResponse, get_awaitable_result | |||||
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer | from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer | ||||
from tests.unittest import HomeserverTestCase, override_config | from tests.unittest import HomeserverTestCase, override_config | ||||
@@ -164,7 +164,7 @@ class OidcHandlerTestCase(HomeserverTestCase): | |||||
auth_handler = hs.get_auth_handler() | auth_handler = hs.get_auth_handler() | ||||
# Mock the complete SSO login method. | # Mock the complete SSO login method. | ||||
self.complete_sso_login = simple_async_mock() | |||||
self.complete_sso_login = AsyncMock() | |||||
auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] | auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] | ||||
return hs | return hs | ||||
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Any, Dict, Optional, Set, Tuple | from typing import Any, Dict, Optional, Set, Tuple | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
import attr | import attr | ||||
@@ -25,7 +25,6 @@ from synapse.server import HomeServer | |||||
from synapse.types import JsonDict | from synapse.types import JsonDict | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.unittest import HomeserverTestCase, override_config | from tests.unittest import HomeserverTestCase, override_config | ||||
# Check if we have the dependencies to run the tests. | # Check if we have the dependencies to run the tests. | ||||
@@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
# send a mocked-up SAML response to the callback | # send a mocked-up SAML response to the callback | ||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | ||||
@@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
# Map a user via SSO. | # Map a user via SSO. | ||||
saml_response = FakeAuthnResponse( | saml_response = FakeAuthnResponse( | ||||
@@ -206,7 +205,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
# mock out the error renderer too | # mock out the error renderer too | ||||
sso_handler = self.hs.get_sso_handler() | sso_handler = self.hs.get_sso_handler() | ||||
@@ -227,7 +226,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler and error renderer | # stub out the auth handler and error renderer | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
sso_handler = self.hs.get_sso_handler() | sso_handler = self.hs.get_sso_handler() | ||||
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] | sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] | ||||
@@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase): | |||||
# stub out the auth handler | # stub out the auth handler | ||||
auth_handler = self.hs.get_auth_handler() | auth_handler = self.hs.get_auth_handler() | ||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] | |||||
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment] | |||||
# The response doesn't have the proper userGroup or department. | # The response doesn't have the proper userGroup or department. | ||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) | ||||
@@ -12,7 +12,7 @@ | |||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Any, Dict, Optional | from typing import Any, Dict, Optional | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
from twisted.internet import defer | from twisted.internet import defer | ||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
@@ -33,7 +33,6 @@ from synapse.util import Clock | |||||
from tests.events.test_presence_router import send_presence_update, sync_presence | from tests.events.test_presence_router import send_presence_update, sync_presence | ||||
from tests.replication._base import BaseMultiWorkerStreamTestCase | from tests.replication._base import BaseMultiWorkerStreamTestCase | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.test_utils.event_injection import inject_member_event | from tests.test_utils.event_injection import inject_member_event | ||||
from tests.unittest import HomeserverTestCase, override_config | from tests.unittest import HomeserverTestCase, override_config | ||||
@@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
# Mock out the calls over federation. | # Mock out the calls over federation. | ||||
self.fed_transport_client = Mock(spec=["send_transaction"]) | self.fed_transport_client = Mock(spec=["send_transaction"]) | ||||
self.fed_transport_client.send_transaction = simple_async_mock({}) | |||||
self.fed_transport_client.send_transaction = AsyncMock(return_value={}) | |||||
return self.setup_test_homeserver( | return self.setup_test_homeserver( | ||||
federation_transport_client=self.fed_transport_client, | federation_transport_client=self.fed_transport_client, | ||||
@@ -579,9 +578,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): | |||||
"""Test that the module API can join a remote room.""" | """Test that the module API can join a remote room.""" | ||||
# Necessary to fake a remote join. | # Necessary to fake a remote join. | ||||
fake_stream_id = 1 | fake_stream_id = 1 | ||||
mocked_remote_join = simple_async_mock( | |||||
return_value=("fake-event-id", fake_stream_id) | |||||
) | |||||
mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id)) | |||||
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] | self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] | ||||
fake_remote_host = f"{self.module_api.server_name}-remote" | fake_remote_host = f"{self.module_api.server_name}-remote" | ||||
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
from typing import Any, Optional | from typing import Any, Optional | ||||
from unittest.mock import patch | |||||
from unittest.mock import AsyncMock, patch | |||||
from parameterized import parameterized | from parameterized import parameterized | ||||
@@ -28,7 +28,6 @@ from synapse.server import HomeServer | |||||
from synapse.types import JsonDict, create_requester | from synapse.types import JsonDict, create_requester | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.unittest import HomeserverTestCase, override_config | from tests.unittest import HomeserverTestCase, override_config | ||||
@@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): | |||||
# Mock the method which calculates push rules -- we do this instead of | # Mock the method which calculates push rules -- we do this instead of | ||||
# e.g. checking the results in the database because we want to ensure | # e.g. checking the results in the database because we want to ensure | ||||
# that code isn't even running. | # that code isn't even running. | ||||
bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] | |||||
bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[assignment] | |||||
# Ensure no actions are generated! | # Ensure no actions are generated! | ||||
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) | self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) | ||||
@@ -11,7 +11,7 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
# See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
# limitations under the License. | # limitations under the License. | ||||
from unittest.mock import Mock | |||||
from unittest.mock import AsyncMock, Mock | |||||
from twisted.test.proto_helpers import MemoryReactor | from twisted.test.proto_helpers import MemoryReactor | ||||
@@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room | |||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.unittest import HomeserverTestCase | from tests.unittest import HomeserverTestCase | ||||
@@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase): | |||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | ||||
# Mock out the calls over federation. | # Mock out the calls over federation. | ||||
fed_transport_client = Mock(spec=["send_transaction"]) | fed_transport_client = Mock(spec=["send_transaction"]) | ||||
fed_transport_client.send_transaction = simple_async_mock({}) | |||||
fed_transport_client.send_transaction = AsyncMock(return_value={}) | |||||
return self.setup_test_homeserver( | return self.setup_test_homeserver( | ||||
federation_transport_client=fed_transport_client, | federation_transport_client=fed_transport_client, | ||||
@@ -32,7 +32,6 @@ from synapse.types import JsonDict | |||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from tests import unittest | from tests import unittest | ||||
from tests.test_utils import simple_async_mock | |||||
from tests.unittest import override_config | from tests.unittest import override_config | ||||
@@ -348,8 +347,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): | |||||
# Mock out the AsyncContextManager | # Mock out the AsyncContextManager | ||||
class MockCM: | class MockCM: | ||||
__aenter__ = simple_async_mock(return_value=None) | |||||
__aexit__ = simple_async_mock(return_value=None) | |||||
__aenter__ = AsyncMock(return_value=None) | |||||
__aexit__ = AsyncMock(return_value=None) | |||||
self._update_ctx_manager = MockCM | self._update_ctx_manager = MockCM | ||||
@@ -19,8 +19,7 @@ import json | |||||
import sys | import sys | ||||
import warnings | import warnings | ||||
from binascii import unhexlify | from binascii import unhexlify | ||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar | |||||
from unittest.mock import Mock | |||||
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar | |||||
import attr | import attr | ||||
import zope.interface | import zope.interface | ||||
@@ -62,10 +61,6 @@ def setup_awaitable_errors() -> Callable[[], None]: | |||||
""" | """ | ||||
warnings.simplefilter("error", RuntimeWarning) | warnings.simplefilter("error", RuntimeWarning) | ||||
# unraisablehook was added in Python 3.8. | |||||
if not hasattr(sys, "unraisablehook"): | |||||
return lambda: None | |||||
# State shared between unraisablehook and check_for_unraisable_exceptions. | # State shared between unraisablehook and check_for_unraisable_exceptions. | ||||
unraisable_exceptions = [] | unraisable_exceptions = [] | ||||
orig_unraisablehook = sys.unraisablehook | orig_unraisablehook = sys.unraisablehook | ||||
@@ -88,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]: | |||||
return cleanup | return cleanup | ||||
def simple_async_mock( | |||||
return_value: Optional[TV] = None, raises: Optional[Exception] = None | |||||
) -> Mock: | |||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour | |||||
async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: | |||||
if raises: | |||||
raise raises | |||||
return return_value | |||||
return Mock(side_effect=cb) | |||||
# Type ignore: it does not fully implement IResponse, but is good enough for tests | # Type ignore: it does not fully implement IResponse, but is good enough for tests | ||||
@zope.interface.implementer(IResponse) | @zope.interface.implementer(IResponse) | ||||
@attr.s(slots=True, frozen=True, auto_attribs=True) | @attr.s(slots=True, frozen=True, auto_attribs=True) | ||||